[Mlir-commits] [mlir] 889b67c - [mlir] [memref] add more checks to the memref.reinterpret_cast (#112669)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 25 17:07:54 PDT 2024


Author: donald chen
Date: 2024-10-26T08:07:51+08:00
New Revision: 889b67c9d30e3024a1317431d66c22599f6c2011

URL: https://github.com/llvm/llvm-project/commit/889b67c9d30e3024a1317431d66c22599f6c2011
DIFF: https://github.com/llvm/llvm-project/commit/889b67c9d30e3024a1317431d66c22599f6c2011.diff

LOG: [mlir] [memref] add more checks to the memref.reinterpret_cast (#112669)

The memref.reinterpret_cast operation previously accepted input like:

%out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10],
strides: [1]
         : memref<?xf32> to memref<10xf32>

A problem arises: while lowering, the true offset of %out is %offset,
but its data type indicates an offset of 0. Permitting this
inconsistency can produce incorrect results, as certain passes might
erroneously extract the offset from the data type of %out.

This patch fixes this by enforcing that the return value's data type
aligns with the input parameters.

Added: 
    

Modified: 
    mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
    mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
    mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
    mlir/test/Dialect/GPU/decompose-memrefs.mlir
    mlir/test/Dialect/MemRef/expand-ops.mlir
    mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
    mlir/test/Dialect/MemRef/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp
index 2b2d10a7733ece..004d73a77e5359 100644
--- a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp
@@ -29,6 +29,17 @@ namespace mlir {
 
 using namespace mlir;
 
+static MemRefType inferCastResultType(Value source, OpFoldResult offset) {
+  auto sourceType = cast<BaseMemRefType>(source.getType());
+  SmallVector<int64_t> staticOffsets;
+  SmallVector<Value> dynamicOffsets;
+  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
+  auto stridedLayout =
+      StridedLayoutAttr::get(source.getContext(), staticOffsets.front(), {});
+  return MemRefType::get({}, sourceType.getElementType(), stridedLayout,
+                         sourceType.getMemorySpace());
+}
+
 static void setInsertionPointToStart(OpBuilder &builder, Value val) {
   if (auto *parentOp = val.getDefiningOp()) {
     builder.setInsertionPointAfter(parentOp);
@@ -98,7 +109,7 @@ static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
   SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
   auto &&[base, offset, ignore] =
       getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
-  auto retType = cast<MemRefType>(base.getType());
+  MemRefType retType = inferCastResultType(base, offset);
   return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
                                                     std::nullopt, std::nullopt);
 }

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index d579a27359dfa0..2219505c9b802f 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1892,11 +1892,12 @@ LogicalResult ReinterpretCastOp::verify() {
   // Match sizes in result memref type and in static_sizes attribute.
   for (auto [idx, resultSize, expectedSize] :
        llvm::enumerate(resultType.getShape(), getStaticSizes())) {
-    if (!ShapedType::isDynamic(resultSize) &&
-        !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
+    if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
       return emitError("expected result type with size = ")
-             << expectedSize << " instead of " << resultSize
-             << " in dim = " << idx;
+             << (ShapedType::isDynamic(expectedSize)
+                     ? std::string("dynamic")
+                     : std::to_string(expectedSize))
+             << " instead of " << resultSize << " in dim = " << idx;
   }
 
   // Match offset and strides in static_offset and static_strides attributes. If
@@ -1910,20 +1911,22 @@ LogicalResult ReinterpretCastOp::verify() {
 
   // Match offset in result memref type and in static_offsets attribute.
   int64_t expectedOffset = getStaticOffsets().front();
-  if (!ShapedType::isDynamic(resultOffset) &&
-      !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
+  if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
     return emitError("expected result type with offset = ")
-           << expectedOffset << " instead of " << resultOffset;
+           << (ShapedType::isDynamic(expectedOffset)
+                   ? std::string("dynamic")
+                   : std::to_string(expectedOffset))
+           << " instead of " << resultOffset;
 
   // Match strides in result memref type and in static_strides attribute.
   for (auto [idx, resultStride, expectedStride] :
        llvm::enumerate(resultStrides, getStaticStrides())) {
-    if (!ShapedType::isDynamic(resultStride) &&
-        !ShapedType::isDynamic(expectedStride) &&
-        resultStride != expectedStride)
+    if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride)
       return emitError("expected result type with stride = ")
-             << expectedStride << " instead of " << resultStride
-             << " in dim = " << idx;
+             << (ShapedType::isDynamic(expectedStride)
+                     ? std::string("dynamic")
+                     : std::to_string(expectedStride))
+             << " instead of " << resultStride << " in dim = " << idx;
   }
 
   return success();

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index faba12f5bf82f8..83683c7e617bf8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -89,7 +89,8 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
     strides.resize(rank);
 
     Location loc = op.getLoc();
-    Value stride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value stride = nullptr;
+    int64_t staticStride = 1;
     for (int i = rank - 1; i >= 0; --i) {
       Value size;
       // Load dynamic sizes from the shape input, use constants for static dims.
@@ -105,9 +106,22 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
         size = rewriter.create<arith::ConstantOp>(loc, sizeAttr);
         sizes[i] = sizeAttr;
       }
-      strides[i] = stride;
-      if (i > 0)
-        stride = rewriter.create<arith::MulIOp>(loc, stride, size);
+      if (stride)
+        strides[i] = stride;
+      else
+        strides[i] = rewriter.getIndexAttr(staticStride);
+
+      if (i > 0) {
+        if (stride) {
+          stride = rewriter.create<arith::MulIOp>(loc, stride, size);
+        } else if (op.getType().isDynamicDim(i)) {
+          stride = rewriter.create<arith::MulIOp>(
+              loc, rewriter.create<arith::ConstantIndexOp>(loc, staticStride),
+              size);
+        } else {
+          staticStride *= op.getType().getDimSize(i);
+        }
+      }
     }
     rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
         op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0),

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index a2049ba4a4924d..087d1fcc2b23ae 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -507,6 +507,8 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
 
   SmallVector<OpFoldResult> groupStrides;
   ArrayRef<int64_t> srcShape = sourceType.getShape();
+
+  OpFoldResult lastValidStride = nullptr;
   for (int64_t currentDim : reassocGroup) {
     // Skip size-of-1 dimensions, since right now their strides may be
     // meaningless.
@@ -517,11 +519,11 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
       continue;
 
     int64_t currentStride = strides[currentDim];
-    groupStrides.push_back(ShapedType::isDynamic(currentStride)
-                               ? origStrides[currentDim]
-                               : builder.getIndexAttr(currentStride));
+    lastValidStride = ShapedType::isDynamic(currentStride)
+                          ? origStrides[currentDim]
+                          : builder.getIndexAttr(currentStride);
   }
-  if (groupStrides.empty()) {
+  if (!lastValidStride) {
     // We're dealing with a 1x1x...x1 shape. The stride is meaningless,
     // but we still have to make the type system happy.
     MemRefType collapsedType = collapseShape.getResultType();
@@ -543,12 +545,7 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
     return {builder.getIndexAttr(finalStride)};
   }
 
-  // For the general case, we just want the minimum stride
-  // since the collapsed dimensions are contiguous.
-  auto minMap = AffineMap::getMultiDimIdentityMap(groupStrides.size(),
-                                                  builder.getContext());
-  return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
-                                      groupStrides)};
+  return {lastValidStride};
 }
 
 /// From `reshape_like(memref, subSizes, subStrides))` compute

diff  --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index 55b1bc9c545a85..ec5ceae57ccb33 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -425,8 +425,6 @@ func.func @collapse_shape_dynamic_with_non_identity_layout(
 // CHECK:           %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[STRIDE0_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0]] : i64 to index
-// CHECK:           %[[STRIDE0:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0_TO_IDX]] : index to i64
 // CHECK:           %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]]  : i64
 // CHECK:           %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
 // CHECK:           %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
@@ -548,23 +546,19 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32>
 // CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK:           %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[STRIDE1:.*]] = llvm.extractvalue %[[MEM]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
 // CHECK:           %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE2]], %[[C2]]  : i64
 // CHECK:           %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
 // CHECK:           %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
-// CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK:           %[[MIN_STRIDE1:.*]] = llvm.intr.smin(%[[STRIDE1]], %[[C1]]) : (i64, i64) -> i64
-// CHECK:           %[[MIN_STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1]] : i64 to index
-// CHECK:           %[[MIN_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1_TO_IDX]] : index to i64
 // CHECK:           %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[DESC5:.*]] = llvm.insertvalue %[[FINAL_SIZE1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[MIN_STRIDE1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[C1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<1x?xf32>
 // CHECK:           return %[[RES]] : memref<1x?xf32>
 // CHECK:         }

diff  --git a/mlir/test/Dialect/GPU/decompose-memrefs.mlir b/mlir/test/Dialect/GPU/decompose-memrefs.mlir
index 56fc9a66b7ace7..1a192219484517 100644
--- a/mlir/test/Dialect/GPU/decompose-memrefs.mlir
+++ b/mlir/test/Dialect/GPU/decompose-memrefs.mlir
@@ -7,8 +7,8 @@
 //       CHECK:  gpu.launch
 //  CHECK-SAME:  threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
 //       CHECK:  %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
-//       CHECK:  %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
-//       CHECK:  memref.store %[[VAL]], %[[PTR]][] : memref<f32>
+//       CHECK:  %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
+//       CHECK:  memref.store %[[VAL]], %[[PTR]][] : memref<f32, strided<[], offset: ?>>
 func.func @decompose_store(%arg0 : f32, %arg1 : memref<?x?x?xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -33,8 +33,8 @@ func.func @decompose_store(%arg0 : f32, %arg1 : memref<?x?x?xf32>) {
 //       CHECK:  gpu.launch
 //  CHECK-SAME:  threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
 //       CHECK:  %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[OFFSET]], %[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]], %[[STRIDES]]#2]
-//       CHECK:  %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
-//       CHECK:  memref.store %[[VAL]], %[[PTR]][] : memref<f32>
+//       CHECK:  %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
+//       CHECK:  memref.store %[[VAL]], %[[PTR]][] : memref<f32, strided<[], offset: ?>>
 func.func @decompose_store_strided(%arg0 : f32, %arg1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -59,8 +59,8 @@ func.func @decompose_store_strided(%arg0 : f32, %arg1 : memref<?x?x?xf32, stride
 //       CHECK:  gpu.launch
 //  CHECK-SAME:  threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
 //       CHECK:  %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
-//       CHECK:  %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
-//       CHECK:  %[[RES:.*]] = memref.load %[[PTR]][] : memref<f32>
+//       CHECK:  %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
+//       CHECK:  %[[RES:.*]] = memref.load %[[PTR]][] : memref<f32, strided<[], offset: ?>>
 //       CHECK:  "test.test"(%[[RES]]) : (f32) -> ()
 func.func @decompose_load(%arg0 : memref<?x?x?xf32>) {
   %c0 = arith.constant 0 : index

diff  --git a/mlir/test/Dialect/MemRef/expand-ops.mlir b/mlir/test/Dialect/MemRef/expand-ops.mlir
index f958a92b751a4a..65932b5814a668 100644
--- a/mlir/test/Dialect/MemRef/expand-ops.mlir
+++ b/mlir/test/Dialect/MemRef/expand-ops.mlir
@@ -52,14 +52,13 @@ func.func @memref_reshape(%input: memref<*xf32>,
 // CHECK-SAME: [[SRC:%.*]]: memref<*xf32>,
 // CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref<?x?x8xf32> {
 
-// CHECK: [[C1:%.*]] = arith.constant 1 : index
 // CHECK: [[C8:%.*]] = arith.constant 8 : index
-// CHECK: [[STRIDE_1:%.*]] = arith.muli [[C1]], [[C8]] : index
-
-// CHECK: [[C1_:%.*]] = arith.constant 1 : index
-// CHECK: [[DIM_1:%.*]] = memref.load [[SHAPE]]{{\[}}[[C1_]]] : memref<3xi32>
+// CHECK: [[C1:%.*]] = arith.constant 1 : index
+// CHECK: [[DIM_1:%.*]] = memref.load [[SHAPE]]{{\[}}[[C1]]] : memref<3xi32>
 // CHECK: [[SIZE_1:%.*]] = arith.index_cast [[DIM_1]] : i32 to index
-// CHECK: [[STRIDE_0:%.*]] = arith.muli [[STRIDE_1]], [[SIZE_1]] : index
+
+// CHECK: [[C8_:%.*]] = arith.constant 8 : index
+// CHECK: [[STRIDE_0:%.*]] = arith.muli [[C8_]], [[SIZE_1]] : index
 
 // CHECK: [[C0:%.*]] = arith.constant 0 : index
 // CHECK: [[DIM_0:%.*]] = memref.load [[SHAPE]]{{\[}}[[C0]]] : memref<3xi32>
@@ -67,5 +66,5 @@ func.func @memref_reshape(%input: memref<*xf32>,
 
 // CHECK: [[RESULT:%.*]] = memref.reinterpret_cast [[SRC]]
 // CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8],
-// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[STRIDE_1]], [[C1]]]
+// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], 8, 1]
 // CHECK-SAME: : memref<*xf32> to memref<?x?x8xf32>

diff  --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 8aac802ba10ae9..647731db439c08 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -931,19 +931,15 @@ func.func @extract_aligned_pointer_as_index_of_unranked_source(%arg0: memref<*xf
 //          = min(7, 1)
 //          = 1
 //
-//   CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0] -> (s0)>
-//   CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
-//   CHECK-DAG: #[[$STRIDE1_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1, 42)>
+//       CHECK: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
 // CHECK-LABEL: func @simplify_collapse(
 //  CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
 //
 //       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
 //
-//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0]
-//   CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
-//   CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.min #[[$STRIDE1_MIN_MAP]]()[%[[STRIDES]]#1, %[[STRIDES]]#2]
+//       CHECK: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
 //
-//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], 1]
+//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[STRIDES]]#0, 42, 1]
 func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
   -> memref<?x?x42xi32> {
 
@@ -1046,15 +1042,12 @@ func.func @simplify_collapse_with_dim_of_size1_and_non_1_stride
 //           We just return the first dynamic one for this group.
 //
 //
-//   CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1)>
 // CHECK-LABEL: func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride(
 //  CHECK-SAME: %[[ARG:.*]]: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2]
 //
 //       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:5, %[[STRIDES:.*]]:5 = memref.extract_strided_metadata %[[ARG]] : memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>
 //
-//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0, %[[STRIDES]]#1]
-//
-//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [6, 1], strides: [%[[DYN_STRIDE0]], %[[STRIDES]]#2]
+//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [6, 1], strides: [%[[STRIDES]]#1, %[[STRIDES]]#2]
 func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
     (%arg0: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>)
     -> memref<6x1xi32, strided<[?, ?], offset: ?>> {
@@ -1083,8 +1076,7 @@ func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
 // Stride 2 = origStride5
 //          = 1
 //
-//   CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
-//   CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0] -> (s0)>
+//       CHECK: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
 // CHECK-LABEL: func @extract_strided_metadata_of_collapse(
 //  CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
 //
@@ -1094,10 +1086,9 @@ func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
 //
 //   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
 //
-//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MAP]]()[%[[STRIDES]]#0]
 //   CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
 //
-//       CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[DYN_STRIDE0]], %[[C42]], %[[C1]]
+//       CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[STRIDES]]#0, %[[C42]], %[[C1]]
 func.func @extract_strided_metadata_of_collapse(%arg : memref<?x?x4x?x6x7xi32>)
   -> (memref<i32>, index,
       index, index, index,

diff  --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 0f533cb95a0ca9..51c4781c9022b2 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -217,6 +217,15 @@ func.func @memref_reinterpret_cast_no_map_but_offset(%in: memref<?xf32>) {
 
 // -----
 
+func.func @memref_reinterpret_cast_offset_mismatch_dynamic(%in: memref<?xf32>, %offset : index) {
+  // expected-error @+1 {{expected result type with offset = dynamic instead of 0}}
+  %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
+         : memref<?xf32> to memref<10xf32>
+  return
+}
+
+// -----
+
 func.func @memref_reinterpret_cast_no_map_but_stride(%in: memref<?xf32>) {
   // expected-error @+1 {{expected result type with stride = 10 instead of 1 in dim = 0}}
   %out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [10]


        


More information about the Mlir-commits mailing list