[Mlir-commits] [mlir] 6d45284 - [mlir][memref] Add better support for identity layouts in memref.collapse_shape canonicalizer

Stephan Herhut llvmlistbot at llvm.org
Thu Jan 20 06:32:06 PST 2022


Author: Stephan Herhut
Date: 2022-01-20T15:31:43+01:00
New Revision: 6d45284618f08fa28dc515cab96fa573c4c4479e

URL: https://github.com/llvm/llvm-project/commit/6d45284618f08fa28dc515cab96fa573c4c4479e
DIFF: https://github.com/llvm/llvm-project/commit/6d45284618f08fa28dc515cab96fa573c4c4479e.diff

LOG: [mlir][memref] Add better support for identity layouts in memref.collapse_shape canonicalizer

When computing the new type of a collapse_shape operation, we need to at least
take into account whether the type has an identity layout, in which case we can
easily support dynamic strides. Otherwise, the canonicalizer creates invalid
IR.

In the longer term, both the verifier and the canonicalizer need to be extended
to support the general case.

Differential Revision: https://reviews.llvm.org/D117772

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 55499f63295f0..211af3045b9d8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1334,6 +1334,7 @@ computeReshapeCollapsedType(MemRefType type,
   AffineExpr offset;
   SmallVector<AffineExpr, 4> strides;
   auto status = getStridesAndOffset(type, strides, offset);
+  auto isIdentityLayout = type.getLayout().isIdentity();
   (void)status;
   assert(succeeded(status) && "expected strided memref");
 
@@ -1350,12 +1351,19 @@ computeReshapeCollapsedType(MemRefType type,
     unsigned dim = m.getNumResults();
     int64_t size = 1;
     AffineExpr stride = strides[currentDim + dim - 1];
-    if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
+    if (isIdentityLayout ||
+        isReshapableDimBand(currentDim, dim, sizes, strides)) {
+      for (unsigned d = 0; d < dim; ++d) {
+        int64_t currentSize = sizes[currentDim + d];
+        if (ShapedType::isDynamic(currentSize)) {
+          size = ShapedType::kDynamicSize;
+          break;
+        }
+        size *= currentSize;
+      }
+    } else {
       size = ShapedType::kDynamicSize;
       stride = AffineExpr();
-    } else {
-      for (unsigned d = 0; d < dim; ++d)
-        size *= sizes[currentDim + d];
     }
     newSizes.push_back(size);
     newStrides.push_back(stride);

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index bd7a8dd830a80..58083437ca47e 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -406,6 +406,8 @@ func @collapse_after_memref_cast_type_change(%arg0 : memref<?x512x1x1xf32>) -> m
   return %collapsed : memref<?x?xf32>
 }
 
+// -----
+
 // CHECK-LABEL:   func @collapse_after_memref_cast(
 // CHECK-SAME:      %[[INPUT:.*]]: memref<?x512x1x?xf32>) -> memref<?x?xf32> {
 // CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
@@ -419,6 +421,21 @@ func @collapse_after_memref_cast(%arg0 : memref<?x512x1x?xf32>) -> memref<?x?xf3
 
 // -----
 
+// CHECK-LABEL:   func @collapse_after_memref_cast_type_change_dynamic(
+// CHECK-SAME:      %[[INPUT:.*]]: memref<1x1x1x?xi64>) -> memref<?x?xi64> {
+// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
+// CHECK_SAME:        {{\[\[}}0, 1, 2], [3]] : memref<1x1x1x?xi64> into memref<1x?xi64>
+// CHECK:           %[[DYNAMIC:.*]] = memref.cast %[[COLLAPSED]] :
+// CHECK-SAME:         memref<1x?xi64> to memref<?x?xi64>
+// CHECK:           return %[[DYNAMIC]] : memref<?x?xi64>
+func @collapse_after_memref_cast_type_change_dynamic(%arg0: memref<1x1x1x?xi64>) -> memref<?x?xi64> {
+  %casted = memref.cast %arg0 : memref<1x1x1x?xi64> to memref<1x1x?x?xi64>
+  %collapsed = memref.collapse_shape %casted [[0, 1, 2], [3]] : memref<1x1x?x?xi64> into memref<?x?xi64>
+  return %collapsed : memref<?x?xi64>
+}
+
+// -----
+
 func @reduced_memref(%arg0: memref<2x5x7x1xf32>, %arg1 :index)
     -> memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> {
   %c0 = arith.constant 0 : index


        


More information about the Mlir-commits mailing list