[Mlir-commits] [mlir] [mlir][memref] Improve `memref.subview` type inference (PR #96421)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jun 23 01:50:47 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

The `memref.subview` result type inference (`SubViewOp::inferResultType`) sometimes produced a dynamic offset even when a static offset could be inferred.

When a dynamic value (stride, size, etc.) is multiplied by zero, the result is always a static 0. Based on this, the result type inference implementation can be improved to produce more static type information in memref types.


---

Patch is 30.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/96421.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+6) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+2-2) 
- (modified) mlir/test/Dialect/Linalg/promote.mlir (+21-21) 
- (modified) mlir/test/Dialect/Linalg/transform-promotion.mlir (+6-6) 
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+8-9) 
- (modified) mlir/test/Dialect/MemRef/subview.mlir (+2-2) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir (+8-8) 
- (modified) mlir/test/Dialect/Vector/vector-transferop-opt.mlir (+5-5) 
- (modified) mlir/test/Transforms/canonicalize.mlir (+5-5) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 594bcf5dbb399..ba4f084d3efd1 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -200,6 +200,12 @@ struct SaturatedInteger {
     return SaturatedInteger{false, other.v + v};
   }
   SaturatedInteger operator*(SaturatedInteger other) {
+    // Multiplication with 0 is always 0.
+    if (!other.saturated && other.v == 0)
+      return SaturatedInteger{false, 0};
+    if (!saturated && v == 0)
+      return SaturatedInteger{false, 0};
+    // Otherwise, if this or the other integer is dynamic, so is the result.
     if (saturated || other.saturated)
       return SaturatedInteger{true, 0};
     return SaturatedInteger{false, other.v * v};
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index d70e6d0b79cd6..f9e8a27973178 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2691,10 +2691,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
   int64_t targetOffset = sourceOffset;
   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
-    auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
+    auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
     targetOffset = (SaturatedInteger::wrap(targetOffset) +
                     SaturatedInteger::wrap(staticOffset) *
-                        SaturatedInteger::wrap(targetStride))
+                        SaturatedInteger::wrap(sourceStride))
                        .asInteger();
   }
 
diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
index 2d640057df340..00b8c649b82c3 100644
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -42,23 +42,23 @@ func.func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 ///
 //       CHECK:         %[[tmpA:.*]] = memref.alloca() : memref<32xi8>
 //       CHECK:         %[[fullA:.*]] = memref.view %[[tmpA]][{{.*}}][{{.*}}] : memref<32xi8> to memref<?x?xf32>
-//       CHECK:         %[[partialA:.*]] = memref.subview %[[fullA]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+//       CHECK:         %[[partialA:.*]] = memref.subview %[[fullA]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
 ///
 //       CHECK:         %[[tmpB:.*]] = memref.alloca() : memref<48xi8>
 //       CHECK:         %[[fullB:.*]] = memref.view %[[tmpB]][{{.*}}][{{.*}}] : memref<48xi8> to memref<?x?xf32>
-//       CHECK:         %[[partialB:.*]] = memref.subview %[[fullB]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+//       CHECK:         %[[partialB:.*]] = memref.subview %[[fullB]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
 ///
 //       CHECK:         %[[tmpC:.*]] = memref.alloca() : memref<24xi8>
 //       CHECK:         %[[fullC:.*]] = memref.view %[[tmpC]][{{.*}}][{{.*}}] : memref<24xi8> to memref<?x?xf32>
-//       CHECK:         %[[partialC:.*]] = memref.subview %[[fullC]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+//       CHECK:         %[[partialC:.*]] = memref.subview %[[fullC]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
 
-//       CHECK:         linalg.copy ins(%[[vA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
-//       CHECK:         linalg.copy ins(%[[vB]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialB]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
-//       CHECK:         linalg.copy ins(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
+//       CHECK:         linalg.copy ins(%[[vA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialA]] : memref<?x?xf32, strided<[?, 1]>>)
+//       CHECK:         linalg.copy ins(%[[vB]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialB]] : memref<?x?xf32, strided<[?, 1]>>)
+//       CHECK:         linalg.copy ins(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialC]] : memref<?x?xf32, strided<[?, 1]>>)
 //
 //       CHECK:         linalg.matmul ins(%[[partialA]], %[[partialB]]{{.*}} outs(%[[partialC]]
 //
-//       CHECK:         linalg.copy ins(%[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
+//       CHECK:         linalg.copy ins(%[[partialC]] : memref<?x?xf32, strided<[?, 1]>>) outs(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
 //
 //   CHECK-NOT:         memref.dealloc %[[tmpA]] : memref<32xi8>
 //   CHECK-NOT:         memref.dealloc %[[tmpB]] : memref<48xi8>
@@ -112,23 +112,23 @@ func.func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 ///
 //       CHECK:         %[[tmpA_f64:.*]] = memref.alloc() : memref<64xi8>
 //       CHECK:         %[[fullA_f64:.*]] = memref.view %[[tmpA_f64]][{{.*}}][{{.*}}] : memref<64xi8> to memref<?x?xf64>
-//       CHECK:         %[[partialA_f64:.*]] = memref.subview %[[fullA_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, strided<[?, 1], offset: ?>>
+//       CHECK:         %[[partialA_f64:.*]] = memref.subview %[[fullA_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, strided<[?, 1]>>
 ///
 //       CHECK:         %[[tmpB_f64:.*]] = memref.alloc() : memref<96xi8>
 //       CHECK:         %[[fullB_f64:.*]] = memref.view %[[tmpB_f64]][{{.*}}][{{.*}}] : memref<96xi8> to memref<?x?xf64>
-//       CHECK:         %[[partialB_f64:.*]] = memref.subview %[[fullB_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, strided<[?, 1], offset: ?>>
+//       CHECK:         %[[partialB_f64:.*]] = memref.subview %[[fullB_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, strided<[?, 1]>>
 ///
 //       CHECK:         %[[tmpC_f64:.*]] = memref.alloc() : memref<48xi8>
 //       CHECK:         %[[fullC_f64:.*]] = memref.view %[[tmpC_f64]][{{.*}}][{{.*}}] : memref<48xi8> to memref<?x?xf64>
-//       CHECK:         %[[partialC_f64:.*]] = memref.subview %[[fullC_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, strided<[?, 1], offset: ?>>
+//       CHECK:         %[[partialC_f64:.*]] = memref.subview %[[fullC_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, strided<[?, 1]>>
 
-//       CHECK:         linalg.copy ins(%[[vA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
-//       CHECK:         linalg.copy ins(%[[vB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
-//       CHECK:         linalg.copy ins(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
+//       CHECK:         linalg.copy ins(%[[vA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialA_f64]] : memref<?x?xf64, strided<[?, 1]>>)
+//       CHECK:         linalg.copy ins(%[[vB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialB_f64]] : memref<?x?xf64, strided<[?, 1]>>)
+//       CHECK:         linalg.copy ins(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1]>>)
 //
 //       CHECK:         linalg.matmul ins(%[[partialA_f64]], %[[partialB_f64]]{{.*}} outs(%[[partialC_f64]]
 //
-//       CHECK:         linalg.copy ins(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
+//       CHECK:         linalg.copy ins(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1]>>) outs(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
 //
 //       CHECK:         memref.dealloc %[[tmpA_f64]] : memref<64xi8>
 //       CHECK:         memref.dealloc %[[tmpB_f64]] : memref<96xi8>
@@ -318,7 +318,7 @@ func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf
   // CHECK:           %[[VAL_21:.*]] = arith.constant 12 : index
   // CHECK:           %[[VAL_22:.*]] = memref.alloc() : memref<48xi8, #gpu.address_space<workgroup>>
   // CHECK:           %[[VAL_23:.*]] = memref.view %[[VAL_22]]{{\[}}%[[VAL_18]]]{{\[}}%[[VAL_12]], %[[VAL_15]]] : memref<48xi8, #gpu.address_space<workgroup>> to memref<?x?xf32, #gpu.address_space<workgroup>>
-  // CHECK:           %[[VAL_24:.*]] = memref.subview %[[VAL_23]][0, 0] {{\[}}%[[VAL_14]], %[[VAL_17]]] [1, 1] : memref<?x?xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
+  // CHECK:           %[[VAL_24:.*]] = memref.subview %[[VAL_23]][0, 0] {{\[}}%[[VAL_14]], %[[VAL_17]]] [1, 1] : memref<?x?xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[?, 1]>, #gpu.address_space<workgroup>>
   // CHECK:           %[[VAL_25:.*]] = arith.constant 0 : index
   // CHECK:           %[[VAL_26:.*]] = arith.constant 4 : index
   // CHECK:           %[[VAL_27:.*]] = arith.constant 1 : index
@@ -337,7 +337,7 @@ func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf
   // CHECK:           %[[VAL_40:.*]] = arith.constant 12 : index
   // CHECK:           %[[VAL_41:.*]] = memref.alloc() : memref<48xi8, #gpu.address_space<workgroup>>
   // CHECK:           %[[VAL_42:.*]] = memref.view %[[VAL_41]]{{\[}}%[[VAL_37]]]{{\[}}%[[VAL_31]], %[[VAL_34]]] : memref<48xi8, #gpu.address_space<workgroup>> to memref<?x?xf32, #gpu.address_space<workgroup>>
-  // CHECK:           %[[VAL_43:.*]] = memref.subview %[[VAL_42]][0, 0] {{\[}}%[[VAL_33]], %[[VAL_36]]] [1, 1] : memref<?x?xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
+  // CHECK:           %[[VAL_43:.*]] = memref.subview %[[VAL_42]][0, 0] {{\[}}%[[VAL_33]], %[[VAL_36]]] [1, 1] : memref<?x?xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[?, 1]>, #gpu.address_space<workgroup>>
   // CHECK:           %[[VAL_44:.*]] = arith.constant 0 : index
   // CHECK:           %[[VAL_45:.*]] = arith.constant 4 : index
   // CHECK:           %[[VAL_46:.*]] = arith.constant 1 : index
@@ -356,10 +356,10 @@ func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf
   // CHECK:           %[[VAL_59:.*]] = arith.constant 12 : index
   // CHECK:           %[[VAL_60:.*]] = memref.alloc() : memref<48xi8, #gpu.address_space<workgroup>>
   // CHECK:           %[[VAL_61:.*]] = memref.view %[[VAL_60]]{{\[}}%[[VAL_56]]]{{\[}}%[[VAL_50]], %[[VAL_53]]] : memref<48xi8, #gpu.address_space<workgroup>> to memref<?x?xf32, #gpu.address_space<workgroup>>
-  // CHECK:           %[[VAL_62:.*]] = memref.subview %[[VAL_61]][0, 0] {{\[}}%[[VAL_52]], %[[VAL_55]]] [1, 1] : memref<?x?xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
-// CHECK:           linalg.copy ins(%[[VAL_3]] : memref<4x3xf32, strided<[4, 1]>, 1>) outs(%[[VAL_24]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>)
-// CHECK:           linalg.copy ins(%[[VAL_4]] : memref<4x3xf32, strided<[4, 1]>, 1>) outs(%[[VAL_43]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>)
-  // CHECK:           linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"], library_call = ""} ins(%[[VAL_24]], %[[VAL_43]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>, memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>) outs(%[[VAL_62]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>) {
+  // CHECK:           %[[VAL_62:.*]] = memref.subview %[[VAL_61]][0, 0] {{\[}}%[[VAL_52]], %[[VAL_55]]] [1, 1] : memref<?x?xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[?, 1]>, #gpu.address_space<workgroup>>
+// CHECK:           linalg.copy ins(%[[VAL_3]] : memref<4x3xf32, strided<[4, 1]>, 1>) outs(%[[VAL_24]] : memref<?x?xf32, strided<[?, 1]>, #gpu.address_space<workgroup>>)
+// CHECK:           linalg.copy ins(%[[VAL_4]] : memref<4x3xf32, strided<[4, 1]>, 1>) outs(%[[VAL_43]] : memref<?x?xf32, strided<[?, 1]>, #gpu.address_space<workgroup>>)
+  // CHECK:           linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"], library_call = ""} ins(%[[VAL_24]], %[[VAL_43]] : memref<?x?xf32, strided<[?, 1]>, #gpu.address_space<workgroup>>, memref<?x?xf32, strided<[?, 1]>, #gpu.address_space<workgroup>>) outs(%[[VAL_62]] : memref<?x?xf32, strided<[?, 1]>, #gpu.address_space<workgroup>>) {
   // CHECK:           ^bb0(%[[VAL_63:.*]]: f32, %[[VAL_64:.*]]: f32, %[[VAL_65:.*]]: f32):
   // CHECK:             %[[VAL_66:.*]] = arith.addf %[[VAL_63]], %[[VAL_64]] : f32
   // CHECK:             linalg.yield %[[VAL_66]] : f32
@@ -372,7 +372,7 @@ func.func @linalg_generic_update_all_function_inputs_outputs(%arg0: memref<3x4xf
     linalg.yield %1 : f32
   }
 
-  // CHECK:           linalg.copy ins(%[[VAL_62]] : memref<?x?xf32, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>) outs(%[[VAL_5]] : memref<4x3xf32, strided<[4, 1]>, 1>)
+  // CHECK:           linalg.copy ins(%[[VAL_62]] : memref<?x?xf32, strided<[?, 1]>, #gpu.address_space<workgroup>>) outs(%[[VAL_5]] : memref<4x3xf32, strided<[4, 1]>, 1>)
   // CHECK:           memref.dealloc %[[VAL_22]] : memref<48xi8, #gpu.address_space<workgroup>>
   // CHECK:           memref.dealloc %[[VAL_41]] : memref<48xi8, #gpu.address_space<workgroup>>
   // CHECK:           memref.dealloc %[[VAL_60]] : memref<48xi8, #gpu.address_space<workgroup>>
diff --git a/mlir/test/Dialect/Linalg/transform-promotion.mlir b/mlir/test/Dialect/Linalg/transform-promotion.mlir
index d6112db0f7772..7c4cd623c742d 100644
--- a/mlir/test/Dialect/Linalg/transform-promotion.mlir
+++ b/mlir/test/Dialect/Linalg/transform-promotion.mlir
@@ -42,15 +42,15 @@ func.func @promote_subview_matmul(%arg0: memref<?x?xf32, strided<[?, 1], offset:
 // CHECK:               %[[a0:.*]] = memref.alloc() : memref<32000000xi8>
 // CHECK:               %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref<?x?xf32>
 // CHECK:               %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1]
-// CHECK-SAME:            memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+// CHECK-SAME:            memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
 // CHECK:               %[[a1:.*]] = memref.alloc() : memref<48000000xi8>
 // CHECK:               %[[v1:.*]] = memref.view %[[a1]]{{.*}} : memref<48000000xi8> to memref<?x?xf32>
 // CHECK:               %[[l1:.*]] = memref.subview %[[v1]][0, 0] [%{{.*}}, %{{.*}}] [1, 1]
-// CHECK-SAME:            memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+// CHECK-SAME:            memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
 // CHECK:               %[[a2:.*]] = memref.alloc() : memref<24000000xi8>
 // CHECK:               %[[v2:.*]] = memref.view %[[a2]]{{.*}} : memref<24000000xi8> to memref<?x?xf32>
 // CHECK:               %[[l2:.*]] = memref.subview %[[v2]][0, 0] [%{{.*}}, %{{.*}}] [1, 1]
-// CHECK-SAME:            memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+// CHECK-SAME:            memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
 // CHECK:               linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
 // CHECK:               linalg.copy ins(%[[s1]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l1]] : memref<?x?xf32, strided{{.*}}>)
 // CHECK:               linalg.copy ins(%[[s2]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l2]] : memref<?x?xf32, strided{{.*}}>)
@@ -110,7 +110,7 @@ func.func @promote_first_subview_matmul(%arg0: memref<?x?xf32, strided<[?, 1], o
 // CHECK:         %[[s2:.*]] = memref.subview {{.*}}: memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
 // CHECK:         %[[a0:.*]] = memref.alloc() : memref<32000000xi8>
 // CHECK:         %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref<?x?xf32>
-// CHECK:         %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+// CHECK:         %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
 // CHECK-NOT:     memref.alloc
 // CHECK-NOT:     memref.view
 // CHECK-NOT:     memref.subview
@@ -147,7 +147,7 @@ func.func @aligned_promote_fill(%arg0: memref<?x?xf32, strided<[?, 1], offset: ?
 // CHECK:         %[[s0:.*]] = memref.subview {{.*}}: memref<?x?xf32, strided{{.*}}> to memref<?x?xf32, strided{{.*}}>
 // CHECK:         %[[a0:.*]] = memref.alloc() {alignment = 32 : i64} : memref<32000000xi8>
 // CHECK:         %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref<?x?xf32>
-// CHECK:         %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+// CHECK:         %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1]>>
 // CHECK:         linalg.fill ins({{.*}} : f32) outs(%[[v0]] : memref<?x?xf32>)
 // CHECK:         linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
 // CHECK:         linalg.fill ins(%[[cf]] : f32) outs(%[[v0]] : memref<?x?xf32>)
@@ -180,7 +180,7 @@ func.func @aligned_promote_fill_complex(%arg0: memref<?x?xcomplex<f32>, strided<
 // CHECK:         %[[s0:.*]] = memref.subview {{.*}}: memref<?x?xcomplex<f32>, strided{{.*}}> to memref<?x?xcomplex<f32>, strided{{.*}}>
 // CHECK:         %[[a0:.*]] = memref.alloc() {alignment = 32 : i64} : memref<64000000xi8>
 // CHECK:         %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<64000000xi8> to memref<?x?xcomplex<f32>>
-// CHECK:         %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xcomplex<f32>> to memref<?x?xcomplex<f32>, strided<[?, 1], offset: ?>>
+// CHECK:         %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xcomplex<f32>> to memref<?x?xcomplex<f32>, strided<[?, 1]>>
 // CHECK:         linalg.fill ins({{.*}} : complex<f32>) outs(%[[v0]] : memref<?x?xcomplex<f32>>)
 // CHECK:         linalg.copy ins(%[[s0]] : memref<?x?xcomplex<f32>, strided{{.*}}>) outs(%[[l0]] : memref<?x?xcomplex<f32>, strided{{.*}}>)
 // CHECK:         linalg.fill ins(%[[cc]] : complex<f32>) outs(%[[v0]] : memref<?x?xcomplex<f32>>)
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index c4ff6480a4ce5..b15af9baca7dc 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -32,15 +32,14 @@ func.func @collapse_expand_rank0_cancel(%arg0 : memref<1x1xi8>) -> memref<1x1xi8
 // CHECK-LABEL: func @subview_of_size_memcast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
 //       CHECK:   %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, strided{{.*}}>
-//       CHECK:   %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, {{.*}}> to memref<16x32xi8, strided{{.*}}>
-//       CHECK:   return %[[M]] : memref<16x32xi8, strided{{.*}}>
+//       CHECK:   return %[[S]] : memref<16x32xi8, strided{{.*}}>
 func.func @subview_of_size_memcast(%arg : memref<4x6x16x32xi8>) ->
-  memref<16x32xi8, strided<[32, 1], offset: ?>>{
+  memref<16x32xi8, strided<[32, 1], offset: 512>>{
   %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
   %1 = memref.subview %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] :
     memref<?x?x16x32xi8> to
-    memref<16x32xi8, strided<[32, 1], offset: ?>>
-  return %1 : memref<16x32xi8, strided<[32, 1], offset: ?>>
+    memref<16x32xi8, strided<[32, 1], offset: 512>>
+  return %1 : memref<16x32xi8, strided<[32, 1], offset: 512>>
 }
 
 // -----
@@ -643,13 +642,13 @@ func.func @no_fold_subv...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/96421


More information about the Mlir-commits mailing list