[Mlir-commits] [mlir] f6897c3 - [mlir][MemRef] Bail out for unsupported cases in FoldMemRefAliasOps pass
Hanhan Wang
llvmlistbot at llvm.org
Fri Aug 11 14:53:21 PDT 2023
Author: Hanhan Wang
Date: 2023-08-11T14:52:53-07:00
New Revision: f6897c37a2b2ca4037c67dca891062431d6eb869
URL: https://github.com/llvm/llvm-project/commit/f6897c37a2b2ca4037c67dca891062431d6eb869
DIFF: https://github.com/llvm/llvm-project/commit/f6897c37a2b2ca4037c67dca891062431d6eb869.diff
LOG: [mlir][MemRef] Bail out for unsupported cases in FoldMemRefAliasOps pass
The pass uses the `computeSuffixProduct` method, which only supports static
shapes. This revision adds an early exit for dynamic shapes so the pattern
bails out instead of crashing.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D157668
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 1fee97a0dd7470..7f8322bd5f6f44 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -63,6 +63,12 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
memref::ExpandShapeOp expandShapeOp,
ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
+ // The below implementation uses computeSuffixProduct method, which only
+ // allows int64_t values (i.e., static shape). Bail out if it has dynamic
+ // shapes.
+ if (!expandShapeOp.getResultType().hasStaticShape())
+ return failure();
+
MLIRContext *ctx = rewriter.getContext();
for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
assert(!groups.empty() && "association indices groups cannot be empty");
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index acfddc366df165..7d6a2e57d958b3 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -331,6 +331,19 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
// -----
+// CHECK-LABEL: fold_dynamic_subview_with_memref_load_store_expand_shape
+func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) -> f32 {
+ %c0 = arith.constant 0 : index
+ %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+ %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+ return %0 : f32
+}
+// CHECK: %[[EXPAND_SHAPE:.+]] = memref.expand_shape {{.+}} : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+// CHECK: %[[LOAD:.+]] = memref.load %[[EXPAND_SHAPE]]
+// CHECK: return %[[LOAD]]
+
+// -----
+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
More information about the Mlir-commits
mailing list