[Mlir-commits] [mlir] [mlir][memref] Fix computeCollapsedLayoutMap for contiguous dynamic dim (PR #136485)
Maya Amrami
llvmlistbot at llvm.org
Sun Apr 20 05:12:33 PDT 2025
https://github.com/amrami created https://github.com/llvm/llvm-project/pull/136485
(No description provided.)
From 47602a637b48ae3921683dd3f6d48c638395ce75 Mon Sep 17 00:00:00 2001
From: Maya Amrami <mayaam88 at gmail.com>
Date: Sun, 20 Apr 2025 14:43:17 +0300
Subject: [PATCH] [mlir][memref] Fix computeCollapsedLayoutMap for contiguous
dynamic dim
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 19 ++++++++++++++-----
mlir/test/Dialect/MemRef/invalid.mlir | 8 ++++++++
2 files changed, 22 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 6f10a31c15626..45ac4c9d5117e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -23,6 +23,7 @@
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
+#include <algorithm>
using namespace mlir;
using namespace mlir::memref;
@@ -2401,11 +2402,19 @@ computeCollapsedLayoutMap(MemRefType srcType,
if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
resultStrides.push_back(srcStrides[ref.back()]);
} else {
- // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
- // the corresponding stride may have to be skipped. (See above comment.)
- // Therefore, the result stride cannot be statically determined and must
- // be dynamic.
- resultStrides.push_back(ShapedType::kDynamic);
+ bool contiguousSrcDim = srcStrides[ref.back()] == 1;
+ bool dynamicSizeIsPreserved =
+ std::all_of(ref.begin(), ref.end() - 1,
+ [srcShape](int64_t dim) { return srcShape[dim] == 1; });
+ if (contiguousSrcDim && dynamicSizeIsPreserved)
+ resultStrides.push_back(1);
+ else {
+ // Dynamically-sized dims may turn out to be dims of size 1 at runtime,
+ // so the corresponding stride may have to be skipped. (See above
+ // comment.) Therefore, the result stride cannot be statically
+ // determined and must be dynamic.
+ resultStrides.push_back(ShapedType::kDynamic);
+ }
}
}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 34fc4775924e7..a8fcd91fba097 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -502,6 +502,14 @@ func.func @collapse_shape_wrong_collapsed_type(%arg0: memref<?x?x?xf32>) {
// -----
+func.func @collapse_shape_infer_stride_of_dynamic_dim(%arg0: memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1>, %dim : index) -> (memref<?xsi32, strided<[?]>, 1>) {
+ // expected-error @+1 {{expected collapsed type to be 'memref<?xsi32, strided<[1]>, 1>' but found 'memref<?xsi32, strided<[?]>, 1>'}}
+ %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1> into memref<?xsi32, strided<[?]>, 1>
+ return %collapse_shape : memref<?xsi32, strided<[?]>, 1>
+}
+
+// -----
+
func.func @expand_shape_illegal_static_memref
(%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> {
// expected-error @+1 {{collapsed dim size (20) must equal reassociation group size (40)}}
More information about the Mlir-commits
mailing list