[Mlir-commits] [mlir] [mlir][memref] Fix index delinearization for CollapseShapeOp folding (PR #68833)

Felix Schneider llvmlistbot at llvm.org
Wed Oct 11 12:44:42 PDT 2023


https://github.com/ubfx created https://github.com/llvm/llvm-project/pull/68833

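The `resolveSourceIndicesCollapseShape` method computes indices into the source `MemRef` of a `CollapseShapeOp` from the collapsed indices. It did not check for dynamic sizes in the source shape, which led to a crash. With this patch, only the most major source dimension of each reassociation group may be dynamic; if any other dimension in a group is dynamic, the method returns failure instead of crashing.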

Fix https://github.com/llvm/llvm-project/issues/68483

>From be718224124b13431cd9e0c4077a647f9684597a Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Wed, 11 Oct 2023 19:43:01 +0000
Subject: [PATCH] [mlir][memref] Fix index delinearization for CollapseShapeOp
 folding

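The `resolveSourceIndicesCollapseShape` method is used to compute
indices into the source `MemRef` of a `CollapseShapeOp` from the
collapsed indices. This method didn't check for dynamic sizes of
the source shape, which led to a crash. Now only the most major
source dimension of each group may be dynamic; the method returns
failure if any other dimension in a group is dynamic.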

Fix https://github.com/llvm/llvm-project/issues/68483
---
 .../MemRef/Transforms/FoldMemRefAliasOps.cpp     | 12 +++++++++---
 .../Dialect/MemRef/fold-memref-alias-ops.mlir    | 16 ++++++++++++++++
 2 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 7f8322bd5f6f445..9899c357daeeeb4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -128,10 +128,16 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
     dynamicIndices.push_back(indices[cnt++]);
     int64_t groupSize = groups.size();
 
-    // Calculate suffix product for all collapse op source dimension sizes.
-    SmallVector<int64_t> sizes(groupSize);
-    for (int64_t i = 0; i < groupSize; ++i)
+    // Calculate suffix product for all collapse op source dimension sizes
+    // except the most major one of each group.
+    // We allow the most major source dimension to be dynamic but enforce all
+    // others to be known statically.
+    SmallVector<int64_t> sizes(groupSize, 1);
+    for (int64_t i = 1; i < groupSize; ++i) {
       sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
+      if (sizes[i] == ShapedType::kDynamic)
+        return failure();
+    }
     SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
 
     // Derive the index values along all dimensions of the source corresponding
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 7d6a2e57d958b35..b94711352745978 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -317,6 +317,22 @@ func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg
 
 // -----
 
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
+// CHECK-LABEL: @fold_dynamic_size_collapse_shape_with_affine_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x6x32xf32> into memref<?x32xf32>
+  %1 = affine.load %0[%arg1, %arg2] : memref<?x32xf32>
+  return %1 : f32
+}
+// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
+// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<?x6x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
 // CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * 6 + s1 * 3 + s2)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
 // CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {



More information about the Mlir-commits mailing list