[Mlir-commits] [mlir] 64f9984 - [mlir][ExpandStridedMetadata] Handle collapse_shape of dim of size 1 gracefully

Quentin Colombet llvmlistbot at llvm.org
Wed Dec 7 23:41:49 PST 2022


Author: Quentin Colombet
Date: 2022-12-08T07:32:01Z
New Revision: 64f99842a6c03fdb57349fcaed3f4821ef612ed1

URL: https://github.com/llvm/llvm-project/commit/64f99842a6c03fdb57349fcaed3f4821ef612ed1
DIFF: https://github.com/llvm/llvm-project/commit/64f99842a6c03fdb57349fcaed3f4821ef612ed1.diff

LOG: [mlir][ExpandStridedMetadata] Handle collapse_shape of dim of size 1 gracefully

Collapsing dimensions of size 1 with random strides (a.k.a.
non-contiguous w.r.t. collapsed dimensions) is a grey area that we'd
like to clean-up. (See https://reviews.llvm.org/D136483#3909856)

That said, the implementation in `memref-to-llvm` currently skips
dimensions of size 1 when computing the stride of a group.

While longer term we may want to clean that up, for now this patch
matches that behavior, at least in the static case.

For the dynamic case, this patch sticks to `min(group strides)`.
However, if we want to handle the dynamic cases correctly while allowing
non-truly-contiguous dynamic sizes of 1, we would need to `if-then-else`
every dynamic size. In other words, compute `min(stride_i, for all i in
group where dim_i != 1)`.

I didn't implement that in this patch at the moment since
`memref-to-llvm` is technically broken in the general case for this. (It
currently would only produce something sensible for row major tensors.)

Differential Revision: https://reviews.llvm.org/D139329

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
    mlir/test/Dialect/MemRef/expand-strided-metadata.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index cac5490a409cb..4613d603ee5ad 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
@@ -403,15 +404,50 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
 
   auto [strides, offset] = getStridesAndOffset(sourceType);
 
-  SmallVector<OpFoldResult> collapsedStride;
-  int64_t innerMostDimForGroup = reassocGroup.back();
-  int64_t innerMostStrideForGroup = strides[innerMostDimForGroup];
-  collapsedStride.push_back(
-      ShapedType::isDynamic(innerMostStrideForGroup)
-          ? origStrides[innerMostDimForGroup]
-          : builder.getIndexAttr(innerMostStrideForGroup));
+  SmallVector<OpFoldResult> groupStrides;
+  ArrayRef<int64_t> srcShape = sourceType.getShape();
+  for (int64_t currentDim : reassocGroup) {
+    // Skip size-of-1 dimensions, since right now their strides may be
+    // meaningless.
+    // FIXME: size-of-1 dimensions shouldn't be used in collapse shape, unless
+    // they are truly contiguous. When they are truly contiguous, we shouldn't
+    // need to skip them.
+    if (srcShape[currentDim] == 1)
+      continue;
+
+    int64_t currentStride = strides[currentDim];
+    groupStrides.push_back(ShapedType::isDynamic(currentStride)
+                               ? origStrides[currentDim]
+                               : builder.getIndexAttr(currentStride));
+  }
+  if (groupStrides.empty()) {
+    // We're dealing with a 1x1x...x1 shape. The stride is meaningless,
+    // but we still have to make the type system happy.
+    MemRefType collapsedType = collapseShape.getResultType();
+    auto [collapsedStrides, collapsedOffset] =
+        getStridesAndOffset(collapsedType);
+    int64_t finalStride = collapsedStrides[groupId];
+    if (ShapedType::isDynamic(finalStride)) {
+      // Look for a dynamic stride. At this point we don't know which one is
+      // desired, but they are all equally good/bad.
+      for (int64_t currentDim : reassocGroup) {
+        assert(srcShape[currentDim] == 1 &&
+               "We should be dealing with 1x1x...x1");
+
+        if (ShapedType::isDynamic(strides[currentDim]))
+          return {origStrides[currentDim]};
+      }
+      llvm_unreachable("We should have found a dynamic stride");
+    }
+    return {builder.getIndexAttr(finalStride)};
+  }
 
-  return collapsedStride;
+  // For the general case, we just want the minimum stride
+  // since the collapsed dimensions are contiguous.
+  auto minMap = AffineMap::getMultiDimIdentityMap(groupStrides.size(),
+                                                  builder.getContext());
+  return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
+                                      groupStrides)};
 }
 /// Replace `baseBuffer, offset, sizes, strides =
 ///              extract_strided_metadata(reshapeLike(memref))`

diff  --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 2b5c90c588d51..0eaf7d186e903 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -907,21 +907,28 @@ func.func @extract_aligned_pointer_as_index(%arg0: memref<f32>) -> index {
 // Size 2 = origSize4 * origSize5
 //        = 6 * 7
 //        = 42
-// Stride 0 = origStride0
-// Stride 1 = origStride3 (orig stride of the inner most dimension)
-//          = 42
-// Stride 2 = origStride5
+// Stride 0 = min(origStride0)
+//          = Right now the folder of affine.min is not smart
+//            enough to just return origStride0
+// Stride 1 = min(origStride1, origStride2, origStride3)
+//          = min(origStride1, origStride2, 42)
+// Stride 2 = min(origStride4, origStride5)
+//          = min(7, 1)
 //          = 1
 //
+//   CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0] -> (s0)>
 //   CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
+//   CHECK-DAG: #[[$STRIDE1_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1, 42)>
 // CHECK-LABEL: func @simplify_collapse(
 //  CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
 //
 //       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
 //
-//       CHECK: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
+//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0]
+//   CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
+//   CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.min #[[$STRIDE1_MIN_MAP]]()[%[[STRIDES]]#1, %[[STRIDES]]#2]
 //
-//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[STRIDES]]#0, 42, 1]
+//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], 1]
 func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
   -> memref<?x?x42xi32> {
 
@@ -934,6 +941,118 @@ func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
 
 // -----
 
+// Check that we simplify collapse_shape into
+// reinterpret_cast(extract_strided_metadata) + <some math>
+// when there are dimensions of size 1 involved.
+//
+// We transform: 3x1 to [0, 1]
+//
+// The tricky bit here is the strides between dimension 0 and 1
+// are not truly contiguous, but since we dealing with a dimension of size 1
+// this is actually fine (i.e., we are not going to jump around.)
+//
+// As a result the resulting stride needs to ignore the strides of the
+// dimensions of size 1.
+//
+// Size 0 = origSize0 * origSize1
+//        = 3 * 1
+//        = 3
+// Stride 0 = min(origStride_i, for all i in reassocation group and dim_i != 1)
+//          = min(origStride0)
+//          = min(2)
+//          = 2
+//
+// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1(
+//  CHECK-SAME: %[[ARG:.*]]: memref<3x1xf32, strided<[2, 1]>>,
+//
+//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<3x1xf32, strided<[2, 1]>>
+//
+//
+//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [3], strides: [2]
+func.func @simplify_collapse_with_dim_of_size1(%arg0: memref<3x1xf32, strided<[2,1]>>, %arg1: memref<3xf32>) {
+
+  %collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
+    memref<3x1xf32, strided<[2, 1]>> into memref<3xf32, strided<[2]>>
+
+  memref.copy %collapse_shape, %arg1 : memref<3xf32, strided<[2]>> to memref<3xf32>
+
+  return
+}
+
+
+// -----
+
+// Check that we simplify collapse_shape with an edge case group of 1x1x...x1.
+//
+// The tricky bit here is also the resulting stride is meaningless, we still
+// have to please the type system.
+//
+// In this case, we're collapsing two strides of respectively 2 and 1 and the
+// resulting type wants a stride of 2.
+//
+// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1_and_non_1_stride(
+//  CHECK-SAME: %[[ARG:.*]]: memref<1x1xi32, strided<[2, 1]
+//
+//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<1x1xi32, strided<[2, 1], offset: ?>>
+//
+//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [1], strides: [2]
+func.func @simplify_collapse_with_dim_of_size1_and_non_1_stride
+    (%arg0: memref<1x1xi32, strided<[2, 1], offset: ?>>)
+    -> memref<1xi32, strided<[2], offset: ?>> {
+
+  %collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
+    memref<1x1xi32, strided<[2, 1], offset: ?>>
+    into memref<1xi32, strided<[2], offset: ?>>
+
+  return %collapse_shape : memref<1xi32, strided<[2], offset: ?>>
+}
+
+// -----
+
+// Check that we simplify collapse_shape with an edge case group of 1x1x...x1.
+// We also have a couple of collapsed dimensions before the 1x1x...x1 group
+// to make sure we properly index into the dynamic strides based on the
+// group ID.
+//
+// The tricky bit in this test is that the 1x1x...x1 group stride is dynamic
+// so we have to propagate one of the dynamic dimension for this group.
+//
+// For this test we have:
+// Size0 = origSize0 * origSize1
+//       = 2 * 3
+//       = 6
+// Size1 = origSize2 * origSize3 * origSize4
+//       = 1 * 1 * 1
+//       = 1
+//
+// Stride0 = min(origStride0, origStride1)
+// Stride1 = we actually don't know, this is dynamic but we don't know
+//           which one to pick.
+//           We just return the first dynamic one for this group.
+//
+//
+//   CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1)>
+// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride(
+//  CHECK-SAME: %[[ARG:.*]]: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2]
+//
+//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:5, %[[STRIDES:.*]]:5 = memref.extract_strided_metadata %[[ARG]] : memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>
+//
+//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0, %[[STRIDES]]#1]
+//
+//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [6, 1], strides: [%[[DYN_STRIDE0]], %[[STRIDES]]#2]
+func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
+    (%arg0: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>)
+    -> memref<6x1xi32, strided<[?, ?], offset: ?>> {
+
+  %collapse_shape = memref.collapse_shape %arg0 [[0, 1], [2, 3, 4]] :
+    memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>
+    into memref<6x1xi32, strided<[?, ?], offset: ?>>
+
+  return %collapse_shape : memref<6x1xi32, strided<[?, ?], offset: ?>>
+}
+
+// -----
+
 // Check that we simplify extract_strided_metadata of collapse_shape.
 //
 // We transform: ?x?x4x?x6x7xi32 to [0][1,2,3][4,5]
@@ -950,6 +1069,7 @@ func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
 //          = 1
 //
 //   CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
+//   CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0] -> (s0)>
 // CHECK-LABEL: func @extract_strided_metadata_of_collapse(
 //  CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
 //
@@ -959,9 +1079,10 @@ func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
 //
 //   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
 //
+//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MAP]]()[%[[STRIDES]]#0]
 //   CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
 //
-//       CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[STRIDES]]#0, %[[C42]], %[[C1]]
+//       CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[DYN_STRIDE0]], %[[C42]], %[[C1]]
 func.func @extract_strided_metadata_of_collapse(%arg : memref<?x?x4x?x6x7xi32>)
   -> (memref<i32>, index,
       index, index, index,


        


More information about the Mlir-commits mailing list