[Mlir-commits] [mlir] [mlir][memref] Fix index delinearization for CollapseShapeOp folding (PR #68833)
Felix Schneider
llvmlistbot at llvm.org
Wed Oct 11 12:44:42 PDT 2023
https://github.com/ubfx created https://github.com/llvm/llvm-project/pull/68833
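The `resolveSourceIndicesCollapseShape` helper computes indices into the source `MemRef` of a `CollapseShapeOp` from the collapsed indices. It did not check for dynamic sizes in the source shape, which led to a crash. With this change, the most-major source dimension of each reassociation group may be dynamic; if any other dimension in a group is dynamic, the fold is rejected.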
Fix https://github.com/llvm/llvm-project/issues/68483
>From be718224124b13431cd9e0c4077a647f9684597a Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Wed, 11 Oct 2023 19:43:01 +0000
Subject: [PATCH] [mlir][memref] Fix index delinearization for CollapseShapeOp
folding
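The `resolveSourceIndicesCollapseShape` helper computes indices into
the source `MemRef` of a `CollapseShapeOp` from the collapsed
indices. It did not check for dynamic sizes in the source shape,
which led to a crash. With this change, the most-major source
dimension of each reassociation group may be dynamic; if any other
dimension in a group is dynamic, the fold is rejected.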
Fix https://github.com/llvm/llvm-project/issues/68483
---
.../MemRef/Transforms/FoldMemRefAliasOps.cpp | 12 +++++++++---
.../Dialect/MemRef/fold-memref-alias-ops.mlir | 16 ++++++++++++++++
2 files changed, 25 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 7f8322bd5f6f445..9899c357daeeeb4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -128,10 +128,16 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
dynamicIndices.push_back(indices[cnt++]);
int64_t groupSize = groups.size();
- // Calculate suffix product for all collapse op source dimension sizes.
- SmallVector<int64_t> sizes(groupSize);
- for (int64_t i = 0; i < groupSize; ++i)
+ // Calculate suffix product for all collapse op source dimension sizes
+ // except the most major one of each group.
+ // We allow the most major source dimension to be dynamic but enforce all
+ // others to be known statically.
+ SmallVector<int64_t> sizes(groupSize, 1);
+ for (int64_t i = 1; i < groupSize; ++i) {
sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
+ if (sizes[i] == ShapedType::kDynamic)
+ return failure();
+ }
SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
// Derive the index values along all dimensions of the source corresponding
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 7d6a2e57d958b35..b94711352745978 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -317,6 +317,22 @@ func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg
// -----
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
+// CHECK-LABEL: @fold_dynamic_size_collapse_shape_with_affine_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x6x32xf32> into memref<?x32xf32> // only the major dim of group [0, 1] is dynamic, so the fold still applies
+ %1 = affine.load %0[%arg1, %arg2] : memref<?x32xf32>
+ return %1 : f32
+}
+// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
+// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<?x6x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * 6 + s1 * 3 + s2)>
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
More information about the Mlir-commits
mailing list