[Mlir-commits] [mlir] andrzej/sme/remove rank 1 (PR #135396)
llvmlistbot at llvm.org
Fri Apr 11 09:24:14 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Andrzej Warzyński (banach-space)
Changes:
- [mlir][vector] Tighten the semantics of vector.{load|store}
- [mlir][ArmSME] Audit arm_sme.tile_store
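For context, a minimal example of the pattern the tightened `vector.load`/`vector.store` verifier now rejects, taken from the new `invalid.mlir` test in this patch (a rank-2 vector loaded from a rank-1 memref):

```mlir
func.func @vector_load(%src : memref<?xi8>) {
  %c0 = arith.constant 0 : index
  // Now a verifier error: the memref (rank 1) has lower rank than the
  // result vector (rank 2).
  %0 = vector.load %src[%c0] : memref<?xi8>, vector<16x16xi8>
  return
}
```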
---
Patch is 24.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135396.diff
10 Files Affected:
- (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td (+2-2)
- (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+6-7)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+7)
- (modified) mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir (-12)
- (modified) mlir/test/Dialect/ArmSME/invalid.mlir (+18)
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+45-22)
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+28-4)
- (modified) mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir (+4-4)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/transpose.mlir (+4-4)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir (+42-39)
``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 6fd992afbf043..23eab706c856d 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -369,7 +369,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
```
}];
let arguments = (ins
- Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
+ Arg<MemRefRankOf<[AnyType], [2]>, "the reference to load from", [MemRead]>:$base,
Variadic<Index>:$indices,
Optional<AnyType>:$padding, Optional<AnyVectorOfNonZeroRank>:$mask,
ArmSME_TileSliceLayoutAttr:$layout
@@ -443,7 +443,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
```
}];
let arguments = (ins SMETile:$valueToStore,
- Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
+ Arg<MemRefRankOf<[AnyType], [2]>, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices, Optional<AnyVectorOfNonZeroRank>:$mask,
ArmSME_TileSliceLayoutAttr:$layout
);
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 6ed29903ea407..9bdafb7d8c501 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -33,20 +33,15 @@ SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
Value tileSliceIndex,
Value tileSliceNumElts, Location loc,
PatternRewriter &rewriter) {
- assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
+ assert(rank == 2 && "memref has unexpected rank!");
SmallVector<Value, 2> outIndices;
auto tileSliceOffset = tileSliceIndex;
- if (rank == 1)
- tileSliceOffset =
- rewriter.create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
auto baseIndexPlusTileSliceOffset =
rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
outIndices.push_back(baseIndexPlusTileSliceOffset);
-
- if (rank == 2)
- outIndices.push_back(indices[1]);
+ outIndices.push_back(indices[1]);
return outIndices;
}
@@ -60,6 +55,10 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
makeLoopBody) {
PatternRewriter::InsertionGuard guard(rewriter);
+ // TODO: This case should be captured and rejected by a verifier.
+ if (memrefIndices.size() != 2)
+ return rewriter.notifyMatchFailure(loc, "invalid number of indices");
+
auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
auto vscale =
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 98d98f067de14..8b70a6b60a1ec 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5099,6 +5099,10 @@ LogicalResult vector::LoadOp::verify() {
if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
return failure();
+ if (memRefTy.getRank() < resVecTy.getRank())
+ return emitOpError(
+ "destination memref has lower rank than the result vector");
+
// Checks for vector memrefs.
Type memElemTy = memRefTy.getElementType();
if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
@@ -5131,6 +5135,9 @@ LogicalResult vector::StoreOp::verify() {
if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
return failure();
+ if (memRefTy.getRank() < valueVecTy.getRank())
+ return emitOpError("source memref has lower rank than the vector to store");
+
// Checks for vector memrefs.
Type memElemTy = memRefTy.getElementType();
if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index 0f973af799634..c8a434bb8f5de 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -718,18 +718,6 @@ func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16
// -----
-// CHECK-LABEL: @vector_load_i8_from_rank_1_memref(
-// CHECK-SAME: %[[MEMREF:.*]]: memref<?xi8>)
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: arm_sme.tile_load %[[MEMREF]][%[[C0]]] : memref<?xi8>, vector<[16]x[16]xi8>
-func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16]x[16]xi8> {
- %c0 = arith.constant 0 : index
- %tile = vector.load %arg0[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
- return %tile : vector<[16]x[16]xi8>
-}
-
-// -----
-
// CHECK-LABEL: @vector_load_i16(
// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 700b2412ff7a7..c015fe7cf1641 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -111,6 +111,15 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
return
}
+// -----
+
+func.func @arm_sme_tile_load__bad_memref_rank(%src : memref<?xf64>, %pad : f64) {
+ %c0 = arith.constant 0 : index
+  // expected-error @+1 {{op operand #0 must be 2D memref of any type values, but got 'memref<?xf64>'}}
+ %tile = arm_sme.tile_load %src[%c0], %pad, : memref<?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
//===----------------------------------------------------------------------===//
// arm_sme.load_tile_slice
//===----------------------------------------------------------------------===//
@@ -138,6 +147,15 @@ func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask
return
}
+// -----
+
+func.func @arm_sme_tile_store__bad_memref_rank(%tile : vector<[16]x[16]xi8>, %dest : memref<?xi8>) {
+ %c0 = arith.constant 0 : index
+  // expected-error @+1 {{op operand #1 must be 2D memref of any type values, but got 'memref<?xi8>'}}
+ arm_sme.tile_store %tile, %dest[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
//===----------------------------------------------------------------------===//
// arm_sme.store_tile_slice
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 067cdb5c5fd20..3160fd9c65c04 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -819,18 +819,29 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
// -----
-func.func @fold_vector_load_subview(
- %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
- %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
- %1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
- return %1 : vector<12x32xf32>
+func.func @fold_vector_load_subview(%src : memref<24x64xf32>,
+ %off1 : index,
+ %off2 : index,
+ %dim1 : index,
+ %dim2 : index,
+ %idx : index) -> vector<12x32xf32> {
+
+ %0 = memref.subview %src[%off1, %off2][%dim1, %dim2][1, 1] : memref<24x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
+ %1 = vector.load %0[%idx, %idx] : memref<?x?xf32, strided<[64, 1], offset: ?>>, vector<12x32xf32>
+ return %1 : vector<12x32xf32>
}
-// CHECK: func @fold_vector_load_subview
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
-// CHECK: vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<12x32xf32>
+// CHECK: #[[$ATTR_46:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func.func @fold_vector_load_subview(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]*]]: memref<24x64xf32>,
+// CHECK-SAME: %[[OFF_1:[a-zA-Z0-9$._-]*]]: index,
+// CHECK-SAME: %[[OFF_2:[a-zA-Z0-9$._-]*]]: index,
+// CHECK-SAME: %[[DIM_1:[a-zA-Z0-9$._-]*]]: index,
+// CHECK-SAME: %[[DIM_2:[a-zA-Z0-9$._-]*]]: index,
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9$._-]*]]: index) -> vector<12x32xf32> {
+// CHECK: %[[VAL_6:.*]] = affine.apply #[[$ATTR_46]](){{\[}}%[[OFF_1]], %[[IDX]]]
+// CHECK: %[[VAL_7:.*]] = affine.apply #[[$ATTR_46]](){{\[}}%[[OFF_2]], %[[IDX]]]
+// CHECK: %[[VAL_8:.*]] = vector.load %[[SRC]]{{\[}}%[[VAL_6]], %[[VAL_7]]] : memref<24x64xf32>, vector<12x32xf32>
// -----
@@ -851,20 +862,32 @@ func.func @fold_vector_maskedload_subview(
// -----
-func.func @fold_vector_store_subview(
- %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<2x32xf32>) -> () {
- %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
- vector.store %arg3, %0[] : memref<f32, strided<[], offset: ?>>, vector<2x32xf32>
- return
+func.func @fold_vector_store_subview(%src : memref<24x64xf32>,
+ %off1 : index,
+ %off2 : index,
+ %vec: vector<2x32xf32>,
+ %idx : index,
+ %dim1 : index,
+ %dim2 : index) -> () {
+
+ %0 = memref.subview %src[%off1, %off2][%dim1, %dim2][1, 1] : memref<24x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
+ vector.store %vec, %0[%idx, %idx] : memref<?x?xf32, strided<[64, 1], offset: ?>> , vector<2x32xf32>
+ return
}
-// CHECK: func @fold_vector_store_subview
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<2x32xf32>
-// CHECK: vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<2x32xf32>
-// CHECK: return
+// CHECK: #[[$ATTR_47:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+
+// CHECK-LABEL: func.func @fold_vector_store_subview(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]*]]: memref<24x64xf32>,
+// CHECK-SAME: %[[OFF1:[a-zA-Z0-9$._-]*]]: index,
+// CHECK-SAME: %[[OFF_2:[a-zA-Z0-9$._-]*]]: index,
+// CHECK-SAME: %[[VEC:[a-zA-Z0-9$._-]*]]: vector<2x32xf32>,
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9$._-]*]]: index,
+// CHECK-SAME: %[[VAL_5:[a-zA-Z0-9$._-]*]]: index,
+// CHECK-SAME: %[[VAL_6:[a-zA-Z0-9$._-]*]]: index) {
+// CHECK: %[[VAL_7:.*]] = affine.apply #[[$ATTR_47]](){{\[}}%[[OFF1]], %[[IDX]]]
+// CHECK: %[[VAL_8:.*]] = affine.apply #[[$ATTR_47]](){{\[}}%[[OFF_2]], %[[IDX]]]
+// CHECK: vector.store %[[VEC]], %[[SRC]]{{\[}}%[[VAL_7]], %[[VAL_8]]] : memref<24x64xf32>, vector<2x32xf32>
// -----
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ea6d0021391fb..f7192fbf68b4e 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1743,13 +1743,11 @@ func.func @invalid_outerproduct(%src : memref<?xf32>) {
// -----
-func.func @invalid_outerproduct1(%src : memref<?xf32>) {
+func.func @invalid_outerproduct1(%src : memref<?xf32>, %lhs : vector<[4]x[4]xf32>, %rhs : vector<[4]xf32>) {
%idx = arith.constant 0 : index
- %0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]x[4]xf32>
- %1 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
// expected-error @+1 {{'vector.outerproduct' op expected 1-d vector for operand #1}}
- %op = vector.outerproduct %0, %1 : vector<[4]x[4]xf32>, vector<[4]xf32>
+ %op = vector.outerproduct %lhs, %rhs : vector<[4]x[4]xf32>, vector<[4]xf32>
}
// -----
@@ -1870,3 +1868,29 @@ func.func @flat_transpose_scalable(%arg0: vector<[16]xf32>) -> vector<[16]xf32>
: vector<[16]xf32> -> vector<[16]xf32>
return %0 : vector<[16]xf32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.load
+//===----------------------------------------------------------------------===//
+
+func.func @vector_load(%src : memref<?xi8>) {
+ %c0 = arith.constant 0 : index
+ // expected-error @+1 {{'vector.load' op destination memref has lower rank than the result vector}}
+ %0 = vector.load %src[%c0] : memref<?xi8>, vector<16x16xi8>
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.store
+//===----------------------------------------------------------------------===//
+
+func.func @vector_store(%dest : memref<?xi8>, %vec : vector<16x16xi8>) {
+ %c0 = arith.constant 0 : index
+ // expected-error @+1 {{'vector.store' op source memref has lower rank than the vector to store}}
+ vector.store %vec, %dest[%c0] : memref<?xi8>, vector<16x16xi8>
+ return
+}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index fd50acf03e79b..511ab70f35086 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -2,8 +2,8 @@
// CHECK-LABEL: func @vector_transfer_ops_0d_memref(
// CHECK-SAME: %[[MEM:.*]]: memref<f32>
-// CHECK-SAME: %[[VEC:.*]]: vector<1x1x1xf32>
-func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf32>) {
+// CHECK-SAME: %[[VEC:.*]]: vector<f32>
+func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<f32>) {
%f0 = arith.constant 0.0 : f32
// CHECK-NEXT: %[[S:.*]] = vector.load %[[MEM]][] : memref<f32>, vector<f32>
@@ -12,8 +12,8 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf
// CHECK-NEXT: vector.store %[[S]], %[[MEM]][] : memref<f32>, vector<f32>
vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>
-// CHECK-NEXT: vector.store %[[VEC]], %[[MEM]][] : memref<f32>, vector<1x1x1xf32>
- vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
+// CHECK-NEXT: vector.store %[[VEC]], %[[MEM]][] : memref<f32>, vector<f32>
+ vector.store %vec, %mem[] : memref<f32>, vector<f32>
return
}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/transpose.mlir
index ff20f99b63cd1..b44658eef4e11 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/transpose.mlir
@@ -17,7 +17,7 @@ func.func @entry() {
%za_s_size = arith.muli %svl_s, %svl_s : index
// Allocate memory.
- %mem1 = memref.alloca(%za_s_size) : memref<?xi32>
+ %mem1 = memref.alloca(%za_s_size, %svl_s) : memref<?x?xi32>
// Fill each "row" of "mem1" with row number.
//
@@ -29,15 +29,15 @@ func.func @entry() {
// 3, 3, 3, 3
//
%init_0 = arith.constant 0 : i32
- scf.for %i = %c0 to %za_s_size step %svl_s iter_args(%val = %init_0) -> (i32) {
+ scf.for %i = %c0 to %svl_s step %c1 iter_args(%val = %init_0) -> (i32) {
%splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
- vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+ vector.store %splat_val, %mem1[%i, %c0] : memref<?x?xi32>, vector<[4]xi32>
%val_next = arith.addi %val, %c1_i32 : i32
scf.yield %val_next : i32
}
// Load tile from "mem1".
- %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+ %tile = vector.load %mem1[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
// Transpose tile.
%transposed_tile = vector.transpose %tile, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index 6e25bee65f095..09d68661c6e9d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -34,11 +34,11 @@ func.func @za0_d_f64() -> i32 {
// 3.1, 3.1, 3.1, 3.1
//
%tilesize = arith.muli %svl_d, %svl_d : index
- %mem1 = memref.alloca(%tilesize) : memref<?xf64>
+ %mem1 = memref.alloca(%svl_d, %svl_d) : memref<?x?xf64>
%init_0 = arith.constant 0.1 : f64
- scf.for %i = %c0 to %tilesize step %svl_d iter_args(%val = %init_0) -> (f64) {
+ scf.for %i = %c0 to %svl_d step %c1_index iter_args(%val = %init_0) -> (f64) {
%splat_val = vector.broadcast %val : f64 to vector<[2]xf64>
- vector.store %splat_val, %mem1[%i] : memref<?xf64>, vector<[2]xf64>
+ vector.store %splat_val, %mem1[%i, %c0] : memref<?x?xf64>, vector<[2]xf64>
%val_next = arith.addf %val, %c1_f64 : f64
scf.yield %val_next : f64
}
@@ -48,27 +48,29 @@ func.func @za0_d_f64() -> i32 {
//
// CHECK-ZA0_D: ( 0.1, 0.1
// CHECK-ZA0_D-NEXT: ( 1.1, 1.1
- scf.for %i = %c0 to %tilesize step %svl_d {
- %tileslice = vector.load %mem1[%i] : memref<?xf64>, vector<[2]xf64>
+ scf.for %i = %c0 to %svl_d step %c1_index {
+ %tileslice = vector.load %mem1[%i, %c0] : memref<?x?xf64>, vector<[2]xf64>
vector.print %tileslice : vector<[2]xf64>
}
// Load ZA0.D from "mem1"
- %za0_d = vector.load %mem1[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
+ %za0_d = vector.load %mem1[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
// Allocate "mem2" to store ZA0.D to
- %mem2 = memref.alloca(%tilesize) : memref<?xf64>
+ %mem2 = memref.alloca(%svl_d, %svl_d) : memref<?x?xf64>
// Zero "mem2"
- scf.for %i = %c0 to %tilesize step %c1_index {
- memref.store %c0_f64, %mem2[%i] : memref<?xf64>
+ scf.for %i = %c0 to %svl_d step %c1_index {
+ scf.for %j = %c0 to %svl_d step %c1_index {
+ memref.store %c0_f64, %mem2[%i, %j] : memref<?x?xf64>
+ }
}
// Verify "mem2" is zeroed by doing an add reduction with initial value of
// zero
%init_0_f64 = arith.constant 0.0 : f64
- %add_reduce = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_0_f64) -> (f64) {
- %row = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>
+ %add_reduce = scf.for %vnum = %c0 to %svl_d step %c1_index iter_args(%iter = %init_0_f64) -> (f64) {
+ %row = vector.load %mem2[%vnum, %c0] : memref<?x?xf64>, vector<[2]xf64>
%inner_add_reduce = scf.for %offset = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_0_f64) -> (f64) {
%t = vector.extractelement %row[%offset : index] : vector<[2]xf64>
@@ -88,16 +90,16 @@ func.func @za0_d_f64() -> i32 {
//
// CHECK-ZA0_D-NEXT: ( 0, 0
// CHECK-ZA0_D-NEXT: ( 0, 0
- scf.for %i = %c0 to %tilesize step %svl_d {
- %tileslice = vector.load %mem2[%i] : memref<?xf64>, vector<[2]xf64>
+ scf.for %i = %c0 to %svl_d step %c1_index{
+ %tileslice = vector.load %mem2[%i, %c0] : memref<?x?xf64>, vector<[2]xf64>
vector.print %tileslice : vector<[2]xf64>
}
// Verify "mem1" != "mem2"
%init_1 = arith.constant 1 : i64
- %mul_reduce_0 = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_1) -> (i64) {
- %row_1 = vector.load %mem1[%vnum] : memref<?xf64>, vector<[2]xf64>
- %row_2 = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>
+ %mul_reduce_0 = scf.for %vnum = %c0 to %svl_d step %c1_index iter_args(%iter = %init_1) -> (i64) {
+ %row_1 = vector.load %mem1[%vnum, %c0] : memref<?x?xf64>, vector<[2]xf64>
+ %row_2 = vector.load %mem2[%vnum, %c0] : memref<?x?xf64>, vector<[2]xf64>
%cmp = arith.cmpf one, %row_1, %row_2 : vector<[2]xf64>
%inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
@@ -115,12 +117,12 @@ func.func @za0_d_f64() -> i32 {
vector.print %mul_reduce_0 : i64
// Store ZA0.D to "mem2"
- vector.store %za0_d, %mem2[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
+ vector.store %za0_d, %mem2[%c0, %c0] :...
[truncated]
``````````
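In short, after this patch `arm_sme.tile_load` and `arm_sme.tile_store` only accept rank-2 memrefs, so code that previously used a flattened rank-1 buffer must allocate a rank-2 view and index with two indices instead. A minimal sketch of the now-required form, mirroring the updated integration tests above (here `%c0` and `%svl_s`, the streaming vector length in elements, are assumed to be defined earlier):

```mlir
// Allocate a rank-2 buffer sized by the streaming vector length, then load a
// whole SME tile from it using two indices.
%mem = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
%tile = vector.load %mem[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
```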
https://github.com/llvm/llvm-project/pull/135396
More information about the Mlir-commits mailing list