[Mlir-commits] [mlir] andrzej/sme/remove rank 1 (PR #135396)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 11 09:24:16 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

- [mlir][vector] Tighten the semantics of vector.{load|store}
- [mlir][ArmSME] Audit arm_sme.tile_store


---

Patch is 24.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135396.diff


10 Files Affected:

- (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td (+2-2) 
- (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+6-7) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+7) 
- (modified) mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir (-12) 
- (modified) mlir/test/Dialect/ArmSME/invalid.mlir (+18) 
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+45-22) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+28-4) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir (+4-4) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/transpose.mlir (+4-4) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir (+42-39) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 6fd992afbf043..23eab706c856d 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -369,7 +369,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
     ```
   }];
   let arguments = (ins
-    Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
+    Arg<MemRefRankOf<[AnyType], [2]>, "the reference to load from", [MemRead]>:$base,
     Variadic<Index>:$indices,
     Optional<AnyType>:$padding, Optional<AnyVectorOfNonZeroRank>:$mask,
     ArmSME_TileSliceLayoutAttr:$layout
@@ -443,7 +443,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
     ```
   }];
   let arguments = (ins SMETile:$valueToStore,
-    Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
+    Arg<MemRefRankOf<[AnyType], [2]>, "the reference to store to", [MemWrite]>:$base,
     Variadic<Index>:$indices, Optional<AnyVectorOfNonZeroRank>:$mask,
     ArmSME_TileSliceLayoutAttr:$layout
   );
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 6ed29903ea407..9bdafb7d8c501 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -33,20 +33,15 @@ SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
                                        Value tileSliceIndex,
                                        Value tileSliceNumElts, Location loc,
                                        PatternRewriter &rewriter) {
-  assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
+  assert(rank == 2 && "memref has unexpected rank!");
   SmallVector<Value, 2> outIndices;
 
   auto tileSliceOffset = tileSliceIndex;
-  if (rank == 1)
-    tileSliceOffset =
-        rewriter.create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
 
   auto baseIndexPlusTileSliceOffset =
       rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
   outIndices.push_back(baseIndexPlusTileSliceOffset);
-
-  if (rank == 2)
-    outIndices.push_back(indices[1]);
+  outIndices.push_back(indices[1]);
 
   return outIndices;
 }
@@ -60,6 +55,10 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
         makeLoopBody) {
   PatternRewriter::InsertionGuard guard(rewriter);
 
+  // TODO: This case should be captured and rejected by a verifier.
+  if (memrefIndices.size() != 2)
+    return rewriter.notifyMatchFailure(loc, "invalid number of indices");
+
   auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
       loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
   auto vscale =
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 98d98f067de14..8b70a6b60a1ec 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5099,6 +5099,10 @@ LogicalResult vector::LoadOp::verify() {
   if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
     return failure();
 
+  if (memRefTy.getRank() < resVecTy.getRank())
+    return emitOpError(
+        "destination memref has lower rank than the result vector");
+
   // Checks for vector memrefs.
   Type memElemTy = memRefTy.getElementType();
   if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
@@ -5131,6 +5135,9 @@ LogicalResult vector::StoreOp::verify() {
   if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
     return failure();
 
+  if (memRefTy.getRank() < valueVecTy.getRank())
+    return emitOpError("source memref has lower rank than the vector to store");
+
   // Checks for vector memrefs.
   Type memElemTy = memRefTy.getElementType();
   if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index 0f973af799634..c8a434bb8f5de 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -718,18 +718,6 @@ func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16
 
 // -----
 
-// CHECK-LABEL: @vector_load_i8_from_rank_1_memref(
-// CHECK-SAME:                                     %[[MEMREF:.*]]: memref<?xi8>)
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: arm_sme.tile_load %[[MEMREF]][%[[C0]]] : memref<?xi8>, vector<[16]x[16]xi8>
-func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16]x[16]xi8> {
-  %c0 = arith.constant 0 : index
-  %tile = vector.load %arg0[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
-  return %tile : vector<[16]x[16]xi8>
-}
-
-// -----
-
 // CHECK-LABEL: @vector_load_i16(
 // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
 func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 700b2412ff7a7..c015fe7cf1641 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -111,6 +111,15 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
   return
 }
 
+// -----
+
+func.func @arm_sme_tile_load__bad_memref_rank(%src : memref<?xf64>, %pad : f64) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{op operand #0 must be 2D memref of any type values, but got 'memref<?xf64>'}}
+  %tile = arm_sme.tile_load %src[%c0], %pad, : memref<?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.load_tile_slice
 //===----------------------------------------------------------------------===//
@@ -138,6 +147,15 @@ func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask
   return
 }
 
+// -----
+
+func.func @arm_sme_tile_store__bad_memref_rank(%tile : vector<[16]x[16]xi8>, %dest : memref<?xi8>) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{op operand #1 must be 2D memref of any type values, but got 'memref<?xi8>'}}
+  arm_sme.tile_store %tile, %dest[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.store_tile_slice
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 067cdb5c5fd20..3160fd9c65c04 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -819,18 +819,29 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
 
 // -----
 
-func.func @fold_vector_load_subview(
-  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
-  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
-  %1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
-  return %1 : vector<12x32xf32>
+func.func @fold_vector_load_subview(%src : memref<24x64xf32>,
+                                    %off1 : index,
+                                    %off2 : index,
+                                    %dim1 : index,
+                                    %dim2 : index,
+                                    %idx : index) -> vector<12x32xf32> {
+
+    %0 = memref.subview %src[%off1, %off2][%dim1, %dim2][1, 1] : memref<24x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
+    %1 = vector.load %0[%idx, %idx] :  memref<?x?xf32, strided<[64, 1], offset: ?>>, vector<12x32xf32>
+    return %1 : vector<12x32xf32>
 }
 
-//      CHECK: func @fold_vector_load_subview
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
-//      CHECK:   vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] :  memref<12x32xf32>, vector<12x32xf32>
+// CHECK: #[[$ATTR_46:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL:   func.func @fold_vector_load_subview(
+// CHECK-SAME:      %[[SRC:[a-zA-Z0-9$._-]*]]: memref<24x64xf32>,
+// CHECK-SAME:      %[[OFF_1:[a-zA-Z0-9$._-]*]]: index,
+// CHECK-SAME:      %[[OFF_2:[a-zA-Z0-9$._-]*]]: index,
+// CHECK-SAME:      %[[DIM_1:[a-zA-Z0-9$._-]*]]: index,
+// CHECK-SAME:      %[[DIM_2:[a-zA-Z0-9$._-]*]]: index,
+// CHECK-SAME:      %[[IDX:[a-zA-Z0-9$._-]*]]: index) -> vector<12x32xf32> {
+// CHECK:           %[[VAL_6:.*]] = affine.apply #[[$ATTR_46]](){{\[}}%[[OFF_1]], %[[IDX]]]
+// CHECK:           %[[VAL_7:.*]] = affine.apply #[[$ATTR_46]](){{\[}}%[[OFF_2]], %[[IDX]]]
+// CHECK:           %[[VAL_8:.*]] = vector.load %[[SRC]]{{\[}}%[[VAL_6]], %[[VAL_7]]] : memref<24x64xf32>, vector<12x32xf32>
 
 // -----
 
@@ -851,20 +862,32 @@ func.func @fold_vector_maskedload_subview(
 
 // -----
 
-func.func @fold_vector_store_subview(
-  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<2x32xf32>) -> () {
-  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
-  vector.store %arg3, %0[] : memref<f32, strided<[], offset: ?>>, vector<2x32xf32>
-  return
+func.func @fold_vector_store_subview(%src : memref<24x64xf32>,
+                                     %off1 : index,
+                                     %off2 : index,
+                                     %vec: vector<2x32xf32>,
+                                     %idx : index,
+                                     %dim1 : index,
+                                     %dim2 : index) -> () {
+
+    %0 = memref.subview %src[%off1, %off2][%dim1, %dim2][1, 1] : memref<24x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
+    vector.store %vec, %0[%idx, %idx] : memref<?x?xf32, strided<[64, 1], offset: ?>> , vector<2x32xf32>
+    return
 }
 
-//      CHECK: func @fold_vector_store_subview
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<2x32xf32>
-//      CHECK:   vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] :  memref<12x32xf32>, vector<2x32xf32>
-//      CHECK:   return
+// CHECK: #[[$ATTR_47:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+
+// CHECK-LABEL:   func.func @fold_vector_store_subview(
+// CHECK-SAME:      %[[SRC:[a-zA-Z0-9$._-]*]]: memref<24x64xf32>,
+// CHECK-SAME:      %[[OFF1:[a-zA-Z0-9$._-]*]]: index,
+// CHECK-SAME:      %[[OFF_2:[a-zA-Z0-9$._-]*]]: index,
+// CHECK-SAME:      %[[VEC:[a-zA-Z0-9$._-]*]]: vector<2x32xf32>,
+// CHECK-SAME:      %[[IDX:[a-zA-Z0-9$._-]*]]: index,
+// CHECK-SAME:      %[[VAL_5:[a-zA-Z0-9$._-]*]]: index,
+// CHECK-SAME:      %[[VAL_6:[a-zA-Z0-9$._-]*]]: index) {
+//      CHECK:      %[[VAL_7:.*]] = affine.apply #[[$ATTR_47]](){{\[}}%[[OFF1]], %[[IDX]]]
+//      CHECK:      %[[VAL_8:.*]] = affine.apply #[[$ATTR_47]](){{\[}}%[[OFF_2]], %[[IDX]]]
+//      CHECK:      vector.store %[[VEC]], %[[SRC]]{{\[}}%[[VAL_7]], %[[VAL_8]]] : memref<24x64xf32>, vector<2x32xf32>
 
 // -----
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ea6d0021391fb..f7192fbf68b4e 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1743,13 +1743,11 @@ func.func @invalid_outerproduct(%src : memref<?xf32>) {
 
 // -----
 
-func.func @invalid_outerproduct1(%src : memref<?xf32>) {
+func.func @invalid_outerproduct1(%src : memref<?xf32>, %lhs : vector<[4]x[4]xf32>, %rhs : vector<[4]xf32>) {
   %idx = arith.constant 0 : index
-  %0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]x[4]xf32>
-  %1 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
 
   // expected-error @+1 {{'vector.outerproduct' op expected 1-d vector for operand #1}}
-  %op = vector.outerproduct %0, %1 : vector<[4]x[4]xf32>, vector<[4]xf32>
+  %op = vector.outerproduct %lhs, %rhs : vector<[4]x[4]xf32>, vector<[4]xf32>
 }
 
 // -----
@@ -1870,3 +1868,29 @@ func.func @flat_transpose_scalable(%arg0: vector<[16]xf32>) -> vector<[16]xf32>
      : vector<[16]xf32> -> vector<[16]xf32>
   return %0 : vector<[16]xf32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.load
+//===----------------------------------------------------------------------===//
+
+func.func @vector_load(%src : memref<?xi8>) {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{'vector.load' op destination memref has lower rank than the result vector}}
+  %0 = vector.load %src[%c0] : memref<?xi8>, vector<16x16xi8>
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.store
+//===----------------------------------------------------------------------===//
+
+func.func @vector_store(%dest : memref<?xi8>, %vec : vector<16x16xi8>) {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{'vector.store' op source memref has lower rank than the vector to store}}
+  vector.store %vec, %dest[%c0] : memref<?xi8>, vector<16x16xi8>
+  return
+}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index fd50acf03e79b..511ab70f35086 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -2,8 +2,8 @@
 
 // CHECK-LABEL: func @vector_transfer_ops_0d_memref(
 //  CHECK-SAME:   %[[MEM:.*]]: memref<f32>
-//  CHECK-SAME:   %[[VEC:.*]]: vector<1x1x1xf32>
-func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf32>) {
+//  CHECK-SAME:   %[[VEC:.*]]: vector<f32>
+func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<f32>) {
     %f0 = arith.constant 0.0 : f32
 
 //  CHECK-NEXT:   %[[S:.*]] = vector.load %[[MEM]][] : memref<f32>, vector<f32>
@@ -12,8 +12,8 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf
 //  CHECK-NEXT:   vector.store %[[S]], %[[MEM]][] : memref<f32>, vector<f32>
     vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>
 
-//  CHECK-NEXT:   vector.store %[[VEC]], %[[MEM]][] : memref<f32>, vector<1x1x1xf32>
-    vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
+//  CHECK-NEXT:   vector.store %[[VEC]], %[[MEM]][] : memref<f32>, vector<f32>
+    vector.store %vec, %mem[] : memref<f32>, vector<f32>
 
     return
 }
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/transpose.mlir
index ff20f99b63cd1..b44658eef4e11 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/transpose.mlir
@@ -17,7 +17,7 @@ func.func @entry() {
   %za_s_size = arith.muli %svl_s, %svl_s : index
 
   // Allocate memory.
-  %mem1 = memref.alloca(%za_s_size) : memref<?xi32>
+  %mem1 = memref.alloca(%za_s_size, %svl_s) : memref<?x?xi32>
 
   // Fill each "row" of "mem1" with row number.
   //
@@ -29,15 +29,15 @@ func.func @entry() {
   //   3, 3, 3, 3
   //
   %init_0 = arith.constant 0 : i32
-  scf.for %i = %c0 to %za_s_size step %svl_s iter_args(%val = %init_0) -> (i32) {
+  scf.for %i = %c0 to %svl_s step %c1 iter_args(%val = %init_0) -> (i32) {
     %splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
-    vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+    vector.store %splat_val, %mem1[%i, %c0] : memref<?x?xi32>, vector<[4]xi32>
     %val_next = arith.addi %val, %c1_i32 : i32
     scf.yield %val_next : i32
   }
 
   // Load tile from "mem1".
-  %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+  %tile = vector.load %mem1[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
 
   // Transpose tile.
   %transposed_tile = vector.transpose %tile, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index 6e25bee65f095..09d68661c6e9d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -34,11 +34,11 @@ func.func @za0_d_f64() -> i32 {
   //   3.1, 3.1, 3.1, 3.1
   //
   %tilesize = arith.muli %svl_d, %svl_d : index
-  %mem1 = memref.alloca(%tilesize) : memref<?xf64>
+  %mem1 = memref.alloca(%svl_d, %svl_d) : memref<?x?xf64>
   %init_0 = arith.constant 0.1 : f64
-  scf.for %i = %c0 to %tilesize step %svl_d iter_args(%val = %init_0) -> (f64) {
+  scf.for %i = %c0 to %svl_d step %c1_index iter_args(%val = %init_0) -> (f64) {
     %splat_val = vector.broadcast %val : f64 to vector<[2]xf64>
-    vector.store %splat_val, %mem1[%i] : memref<?xf64>, vector<[2]xf64>
+    vector.store %splat_val, %mem1[%i, %c0] : memref<?x?xf64>, vector<[2]xf64>
     %val_next = arith.addf %val, %c1_f64 : f64
     scf.yield %val_next : f64
   }
@@ -48,27 +48,29 @@ func.func @za0_d_f64() -> i32 {
   //
   // CHECK-ZA0_D:      ( 0.1, 0.1
   // CHECK-ZA0_D-NEXT: ( 1.1, 1.1
-  scf.for %i = %c0 to %tilesize step %svl_d {
-    %tileslice = vector.load %mem1[%i] : memref<?xf64>, vector<[2]xf64>
+  scf.for %i = %c0 to %svl_d step %c1_index {
+    %tileslice = vector.load %mem1[%i, %c0] : memref<?x?xf64>, vector<[2]xf64>
     vector.print %tileslice : vector<[2]xf64>
   }
 
   // Load ZA0.D from "mem1"
-  %za0_d = vector.load %mem1[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
+  %za0_d = vector.load %mem1[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
 
   // Allocate "mem2" to store ZA0.D to
-  %mem2 = memref.alloca(%tilesize) : memref<?xf64>
+  %mem2 = memref.alloca(%svl_d, %svl_d) : memref<?x?xf64>
 
   // Zero "mem2"
-  scf.for %i = %c0 to %tilesize step %c1_index {
-    memref.store %c0_f64, %mem2[%i] : memref<?xf64>
+  scf.for %i = %c0 to %svl_d step %c1_index {
+    scf.for %j = %c0 to %svl_d step %c1_index {
+      memref.store %c0_f64, %mem2[%i, %j] : memref<?x?xf64>
+    }
   }
 
   // Verify "mem2" is zeroed by doing an add reduction with initial value of
   // zero
   %init_0_f64 = arith.constant 0.0 : f64
-  %add_reduce = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_0_f64) -> (f64) {
-    %row = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>
+  %add_reduce = scf.for %vnum = %c0 to %svl_d step %c1_index iter_args(%iter = %init_0_f64) -> (f64) {
+    %row = vector.load %mem2[%vnum, %c0] : memref<?x?xf64>, vector<[2]xf64>
 
     %inner_add_reduce = scf.for %offset = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_0_f64) -> (f64) {
       %t = vector.extractelement %row[%offset : index] : vector<[2]xf64>
@@ -88,16 +90,16 @@ func.func @za0_d_f64() -> i32 {
   //
   // CHECK-ZA0_D-NEXT: ( 0, 0
   // CHECK-ZA0_D-NEXT: ( 0, 0
-  scf.for %i = %c0 to %tilesize step %svl_d {
-    %tileslice = vector.load %mem2[%i] : memref<?xf64>, vector<[2]xf64>
+  scf.for %i = %c0 to  %svl_d  step %c1_index{
+    %tileslice = vector.load %mem2[%i, %c0] : memref<?x?xf64>, vector<[2]xf64>
     vector.print %tileslice : vector<[2]xf64>
   }
 
   // Verify "mem1" != "mem2"
   %init_1 = arith.constant 1 : i64
-  %mul_reduce_0 = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_1) -> (i64) {
-    %row_1 = vector.load %mem1[%vnum] : memref<?xf64>, vector<[2]xf64>
-    %row_2 = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>
+  %mul_reduce_0 = scf.for %vnum = %c0 to %svl_d step %c1_index iter_args(%iter = %init_1) -> (i64) {
+    %row_1 = vector.load %mem1[%vnum, %c0] : memref<?x?xf64>, vector<[2]xf64>
+    %row_2 = vector.load %mem2[%vnum, %c0] : memref<?x?xf64>, vector<[2]xf64>
     %cmp = arith.cmpf one, %row_1, %row_2 : vector<[2]xf64>
 
     %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
@@ -115,12 +117,12 @@ func.func @za0_d_f64() -> i32 {
   vector.print %mul_reduce_0 : i64
 
   // Store ZA0.D to "mem2"
-  vector.store %za0_d, %mem2[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
+  vector.store %za0_d, %mem2[%c0, %c0] :...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/135396


More information about the Mlir-commits mailing list