[Mlir-commits] [mlir] 889b67c - [mlir] [memref] add more checks to the memref.reinterpret_cast (#112669)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 25 17:07:54 PDT 2024
Author: donald chen
Date: 2024-10-26T08:07:51+08:00
New Revision: 889b67c9d30e3024a1317431d66c22599f6c2011
URL: https://github.com/llvm/llvm-project/commit/889b67c9d30e3024a1317431d66c22599f6c2011
DIFF: https://github.com/llvm/llvm-project/commit/889b67c9d30e3024a1317431d66c22599f6c2011.diff
LOG: [mlir] [memref] add more checks to the memref.reinterpret_cast (#112669)
Previously, the memref.reinterpret_cast operation accepted input like:
%out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10],
strides: [1]
: memref<?xf32> to memref<10xf32>
A problem arises: while lowering, the true offset of %out is %offset,
but its data type indicates an offset of 0. Permitting this
inconsistency can produce incorrect results, as certain passes might
erroneously extract the offset from the data type of %out.
This patch fixes this by enforcing that the result's data type
agrees with the offset, sizes, and strides given as input parameters.
Added:
Modified:
mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
mlir/test/Dialect/GPU/decompose-memrefs.mlir
mlir/test/Dialect/MemRef/expand-ops.mlir
mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
mlir/test/Dialect/MemRef/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp
index 2b2d10a7733ece..004d73a77e5359 100644
--- a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp
@@ -29,6 +29,17 @@ namespace mlir {
using namespace mlir;
+static MemRefType inferCastResultType(Value source, OpFoldResult offset) {
+ auto sourceType = cast<BaseMemRefType>(source.getType());
+ SmallVector<int64_t> staticOffsets;
+ SmallVector<Value> dynamicOffsets;
+ dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
+ auto stridedLayout =
+ StridedLayoutAttr::get(source.getContext(), staticOffsets.front(), {});
+ return MemRefType::get({}, sourceType.getElementType(), stridedLayout,
+ sourceType.getMemorySpace());
+}
+
static void setInsertionPointToStart(OpBuilder &builder, Value val) {
if (auto *parentOp = val.getDefiningOp()) {
builder.setInsertionPointAfter(parentOp);
@@ -98,7 +109,7 @@ static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
auto &&[base, offset, ignore] =
getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
- auto retType = cast<MemRefType>(base.getType());
+ MemRefType retType = inferCastResultType(base, offset);
return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
std::nullopt, std::nullopt);
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index d579a27359dfa0..2219505c9b802f 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1892,11 +1892,12 @@ LogicalResult ReinterpretCastOp::verify() {
// Match sizes in result memref type and in static_sizes attribute.
for (auto [idx, resultSize, expectedSize] :
llvm::enumerate(resultType.getShape(), getStaticSizes())) {
- if (!ShapedType::isDynamic(resultSize) &&
- !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
+ if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
return emitError("expected result type with size = ")
- << expectedSize << " instead of " << resultSize
- << " in dim = " << idx;
+ << (ShapedType::isDynamic(expectedSize)
+ ? std::string("dynamic")
+ : std::to_string(expectedSize))
+ << " instead of " << resultSize << " in dim = " << idx;
}
// Match offset and strides in static_offset and static_strides attributes. If
@@ -1910,20 +1911,22 @@ LogicalResult ReinterpretCastOp::verify() {
// Match offset in result memref type and in static_offsets attribute.
int64_t expectedOffset = getStaticOffsets().front();
- if (!ShapedType::isDynamic(resultOffset) &&
- !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
+ if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
return emitError("expected result type with offset = ")
- << expectedOffset << " instead of " << resultOffset;
+ << (ShapedType::isDynamic(expectedOffset)
+ ? std::string("dynamic")
+ : std::to_string(expectedOffset))
+ << " instead of " << resultOffset;
// Match strides in result memref type and in static_strides attribute.
for (auto [idx, resultStride, expectedStride] :
llvm::enumerate(resultStrides, getStaticStrides())) {
- if (!ShapedType::isDynamic(resultStride) &&
- !ShapedType::isDynamic(expectedStride) &&
- resultStride != expectedStride)
+ if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride)
return emitError("expected result type with stride = ")
- << expectedStride << " instead of " << resultStride
- << " in dim = " << idx;
+ << (ShapedType::isDynamic(expectedStride)
+ ? std::string("dynamic")
+ : std::to_string(expectedStride))
+ << " instead of " << resultStride << " in dim = " << idx;
}
return success();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index faba12f5bf82f8..83683c7e617bf8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -89,7 +89,8 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
strides.resize(rank);
Location loc = op.getLoc();
- Value stride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value stride = nullptr;
+ int64_t staticStride = 1;
for (int i = rank - 1; i >= 0; --i) {
Value size;
// Load dynamic sizes from the shape input, use constants for static dims.
@@ -105,9 +106,22 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
size = rewriter.create<arith::ConstantOp>(loc, sizeAttr);
sizes[i] = sizeAttr;
}
- strides[i] = stride;
- if (i > 0)
- stride = rewriter.create<arith::MulIOp>(loc, stride, size);
+ if (stride)
+ strides[i] = stride;
+ else
+ strides[i] = rewriter.getIndexAttr(staticStride);
+
+ if (i > 0) {
+ if (stride) {
+ stride = rewriter.create<arith::MulIOp>(loc, stride, size);
+ } else if (op.getType().isDynamicDim(i)) {
+ stride = rewriter.create<arith::MulIOp>(
+ loc, rewriter.create<arith::ConstantIndexOp>(loc, staticStride),
+ size);
+ } else {
+ staticStride *= op.getType().getDimSize(i);
+ }
+ }
}
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0),
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index a2049ba4a4924d..087d1fcc2b23ae 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -507,6 +507,8 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
SmallVector<OpFoldResult> groupStrides;
ArrayRef<int64_t> srcShape = sourceType.getShape();
+
+ OpFoldResult lastValidStride = nullptr;
for (int64_t currentDim : reassocGroup) {
// Skip size-of-1 dimensions, since right now their strides may be
// meaningless.
@@ -517,11 +519,11 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
continue;
int64_t currentStride = strides[currentDim];
- groupStrides.push_back(ShapedType::isDynamic(currentStride)
- ? origStrides[currentDim]
- : builder.getIndexAttr(currentStride));
+ lastValidStride = ShapedType::isDynamic(currentStride)
+ ? origStrides[currentDim]
+ : builder.getIndexAttr(currentStride);
}
- if (groupStrides.empty()) {
+ if (!lastValidStride) {
// We're dealing with a 1x1x...x1 shape. The stride is meaningless,
// but we still have to make the type system happy.
MemRefType collapsedType = collapseShape.getResultType();
@@ -543,12 +545,7 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
return {builder.getIndexAttr(finalStride)};
}
- // For the general case, we just want the minimum stride
- // since the collapsed dimensions are contiguous.
- auto minMap = AffineMap::getMultiDimIdentityMap(groupStrides.size(),
- builder.getContext());
- return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
- groupStrides)};
+ return {lastValidStride};
}
/// From `reshape_like(memref, subSizes, subStrides))` compute
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index 55b1bc9c545a85..ec5ceae57ccb33 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -425,8 +425,6 @@ func.func @collapse_shape_dynamic_with_non_identity_layout(
// CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: %[[STRIDE0_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0]] : i64 to index
-// CHECK: %[[STRIDE0:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0_TO_IDX]] : index to i64
// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]] : i64
// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
@@ -548,23 +546,19 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32>
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEM]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE2]], %[[C2]] : i64
// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
-// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK: %[[MIN_STRIDE1:.*]] = llvm.intr.smin(%[[STRIDE1]], %[[C1]]) : (i64, i64) -> i64
-// CHECK: %[[MIN_STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1]] : i64 to index
-// CHECK: %[[MIN_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1_TO_IDX]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[FINAL_SIZE1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[MIN_STRIDE1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[C1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<1x?xf32>
// CHECK: return %[[RES]] : memref<1x?xf32>
// CHECK: }
diff --git a/mlir/test/Dialect/GPU/decompose-memrefs.mlir b/mlir/test/Dialect/GPU/decompose-memrefs.mlir
index 56fc9a66b7ace7..1a192219484517 100644
--- a/mlir/test/Dialect/GPU/decompose-memrefs.mlir
+++ b/mlir/test/Dialect/GPU/decompose-memrefs.mlir
@@ -7,8 +7,8 @@
// CHECK: gpu.launch
// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
-// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
-// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32>
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
+// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32, strided<[], offset: ?>>
func.func @decompose_store(%arg0 : f32, %arg1 : memref<?x?x?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -33,8 +33,8 @@ func.func @decompose_store(%arg0 : f32, %arg1 : memref<?x?x?xf32>) {
// CHECK: gpu.launch
// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[OFFSET]], %[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]], %[[STRIDES]]#2]
-// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
-// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32>
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
+// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32, strided<[], offset: ?>>
func.func @decompose_store_strided(%arg0 : f32, %arg1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -59,8 +59,8 @@ func.func @decompose_store_strided(%arg0 : f32, %arg1 : memref<?x?x?xf32, stride
// CHECK: gpu.launch
// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
-// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
-// CHECK: %[[RES:.*]] = memref.load %[[PTR]][] : memref<f32>
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
+// CHECK: %[[RES:.*]] = memref.load %[[PTR]][] : memref<f32, strided<[], offset: ?>>
// CHECK: "test.test"(%[[RES]]) : (f32) -> ()
func.func @decompose_load(%arg0 : memref<?x?x?xf32>) {
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir
index f958a92b751a4a..65932b5814a668 100644
--- a/mlir/test/Dialect/MemRef/expand-ops.mlir
+++ b/mlir/test/Dialect/MemRef/expand-ops.mlir
@@ -52,14 +52,13 @@ func.func @memref_reshape(%input: memref<*xf32>,
// CHECK-SAME: [[SRC:%.*]]: memref<*xf32>,
// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref<?x?x8xf32> {
-// CHECK: [[C1:%.*]] = arith.constant 1 : index
// CHECK: [[C8:%.*]] = arith.constant 8 : index
-// CHECK: [[STRIDE_1:%.*]] = arith.muli [[C1]], [[C8]] : index
-
-// CHECK: [[C1_:%.*]] = arith.constant 1 : index
-// CHECK: [[DIM_1:%.*]] = memref.load [[SHAPE]]{{\[}}[[C1_]]] : memref<3xi32>
+// CHECK: [[C1:%.*]] = arith.constant 1 : index
+// CHECK: [[DIM_1:%.*]] = memref.load [[SHAPE]]{{\[}}[[C1]]] : memref<3xi32>
// CHECK: [[SIZE_1:%.*]] = arith.index_cast [[DIM_1]] : i32 to index
-// CHECK: [[STRIDE_0:%.*]] = arith.muli [[STRIDE_1]], [[SIZE_1]] : index
+
+// CHECK: [[C8_:%.*]] = arith.constant 8 : index
+// CHECK: [[STRIDE_0:%.*]] = arith.muli [[C8_]], [[SIZE_1]] : index
// CHECK: [[C0:%.*]] = arith.constant 0 : index
// CHECK: [[DIM_0:%.*]] = memref.load [[SHAPE]]{{\[}}[[C0]]] : memref<3xi32>
@@ -67,5 +66,5 @@ func.func @memref_reshape(%input: memref<*xf32>,
// CHECK: [[RESULT:%.*]] = memref.reinterpret_cast [[SRC]]
// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8],
-// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[STRIDE_1]], [[C1]]]
+// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], 8, 1]
// CHECK-SAME: : memref<*xf32> to memref<?x?x8xf32>
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 8aac802ba10ae9..647731db439c08 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -931,19 +931,15 @@ func.func @extract_aligned_pointer_as_index_of_unranked_source(%arg0: memref<*xf
// = min(7, 1)
// = 1
//
-// CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0] -> (s0)>
-// CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
-// CHECK-DAG: #[[$STRIDE1_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1, 42)>
+// CHECK: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
// CHECK-LABEL: func @simplify_collapse(
// CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
//
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
//
-// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0]
-// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
-// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.min #[[$STRIDE1_MIN_MAP]]()[%[[STRIDES]]#1, %[[STRIDES]]#2]
+// CHECK: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
//
-// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], 1]
+// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[STRIDES]]#0, 42, 1]
func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
-> memref<?x?x42xi32> {
@@ -1046,15 +1042,12 @@ func.func @simplify_collapse_with_dim_of_size1_and_non_1_stride
// We just return the first dynamic one for this group.
//
//
-// CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1)>
// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride(
// CHECK-SAME: %[[ARG:.*]]: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2]
//
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:5, %[[STRIDES:.*]]:5 = memref.extract_strided_metadata %[[ARG]] : memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>
//
-// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0, %[[STRIDES]]#1]
-//
-// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [6, 1], strides: [%[[DYN_STRIDE0]], %[[STRIDES]]#2]
+// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [6, 1], strides: [%[[STRIDES]]#1, %[[STRIDES]]#2]
func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
(%arg0: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>)
-> memref<6x1xi32, strided<[?, ?], offset: ?>> {
@@ -1083,8 +1076,7 @@ func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
// Stride 2 = origStride5
// = 1
//
-// CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
-// CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0] -> (s0)>
+// CHECK: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
// CHECK-LABEL: func @extract_strided_metadata_of_collapse(
// CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
//
@@ -1094,10 +1086,9 @@ func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
//
-// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MAP]]()[%[[STRIDES]]#0]
// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
//
-// CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[DYN_STRIDE0]], %[[C42]], %[[C1]]
+// CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[STRIDES]]#0, %[[C42]], %[[C1]]
func.func @extract_strided_metadata_of_collapse(%arg : memref<?x?x4x?x6x7xi32>)
-> (memref<i32>, index,
index, index, index,
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 0f533cb95a0ca9..51c4781c9022b2 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -217,6 +217,15 @@ func.func @memref_reinterpret_cast_no_map_but_offset(%in: memref<?xf32>) {
// -----
+func.func @memref_reinterpret_cast_offset_mismatch_dynamic(%in: memref<?xf32>, %offset : index) {
+ // expected-error @+1 {{expected result type with offset = dynamic instead of 0}}
+ %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
+ : memref<?xf32> to memref<10xf32>
+ return
+}
+
+// -----
+
func.func @memref_reinterpret_cast_no_map_but_stride(%in: memref<?xf32>) {
// expected-error @+1 {{expected result type with stride = 10 instead of 1 in dim = 0}}
%out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [10]
More information about the Mlir-commits
mailing list