[Mlir-commits] [mlir] [mlir][Vector] Remove VectorLoadToMemrefLoadLowering and VectorStoreToMemrefStoreLowering (PR #121454)

llvmlistbot at llvm.org
Wed Jan 1 22:32:06 PST 2025


llvmbot wrote:


@llvm/pr-subscribers-mlir-vector

Author: None (meltq)

Changes:

0-d vectors are now supported, so these patterns are no longer required. This addresses part of https://github.com/llvm/llvm-project/issues/112913.
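The updated test below illustrates the effect: a 0-d `vector.transfer_read` now lowers directly to `vector.load` instead of being scalarized into `memref.load` + `vector.broadcast`:

```mlir
// Lowering before this patch: the 0-d read was scalarized.
%s = memref.load %mem[] : memref<f32>
%before = vector.broadcast %s : f32 to vector<f32>

// Lowering with this patch: the read stays in the vector domain.
%after = vector.load %mem[] : memref<f32>, vector<f32>
```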

---
Full diff: https://github.com/llvm/llvm-project/pull/121454.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp (-57) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir (+10-14) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index f9428a4ce28640..156e33257d8716 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -492,60 +492,6 @@ struct TransferReadToVectorLoadLowering
   std::optional<unsigned> maxTransferRank;
 };
 
-/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
-// TODO: we shouldn't cross the vector/scalar domains just for this
-// but atm we lack the infra to avoid it. Possible solutions include:
-// - go directly to LLVM + bitcast
-// - introduce a bitcast op and likely a new pointer dialect
-// - let memref.load/store additionally support the 0-d vector case
-// There are still deeper data layout issues lingering even in this
-// trivial case (for architectures for which this matters).
-struct VectorLoadToMemrefLoadLowering
-    : public OpRewritePattern<vector::LoadOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
-                                PatternRewriter &rewriter) const override {
-    auto vecType = loadOp.getVectorType();
-    if (vecType.getNumElements() != 1)
-      return rewriter.notifyMatchFailure(loadOp, "not a single element vector");
-
-    auto memrefLoad = rewriter.create<memref::LoadOp>(
-        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
-                                                     memrefLoad);
-    return success();
-  }
-};
-
-/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
-struct VectorStoreToMemrefStoreLowering
-    : public OpRewritePattern<vector::StoreOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
-                                PatternRewriter &rewriter) const override {
-    auto vecType = storeOp.getVectorType();
-    if (vecType.getNumElements() != 1)
-      return rewriter.notifyMatchFailure(storeOp, "not single element vector");
-
-    Value extracted;
-    if (vecType.getRank() == 0) {
-      // TODO: Unifiy once ExtractOp supports 0-d vectors.
-      extracted = rewriter.create<vector::ExtractElementOp>(
-          storeOp.getLoc(), storeOp.getValueToStore());
-    } else {
-      SmallVector<int64_t> indices(vecType.getRank(), 0);
-      extracted = rewriter.create<vector::ExtractOp>(
-          storeOp.getLoc(), storeOp.getValueToStore(), indices);
-    }
-
-    rewriter.replaceOpWithNewOp<memref::StoreOp>(
-        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
-    return success();
-  }
-};
-
 /// Progressive lowering of transfer_write. This pattern supports lowering of
 /// `vector.transfer_write` to `vector.store` if all of the following hold:
 /// - Stride of most minor memref dimension must be 1.
@@ -645,7 +591,4 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
   patterns.add<TransferReadToVectorLoadLowering,
                TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                    maxTransferRank, benefit);
-  patterns
-      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
-          patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index f90111b4c88618..fd50acf03e79b1 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -6,16 +6,13 @@
 func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf32>) {
     %f0 = arith.constant 0.0 : f32
 
-//  CHECK-NEXT:   %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
-//  CHECK-NEXT:   %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<f32>
+//  CHECK-NEXT:   %[[S:.*]] = vector.load %[[MEM]][] : memref<f32>, vector<f32>
     %0 = vector.transfer_read %mem[], %f0 : memref<f32>, vector<f32>
 
-//  CHECK-NEXT:   %[[SS:.*]] = vector.extractelement %[[V]][] : vector<f32>
-//  CHECK-NEXT:   memref.store %[[SS]], %[[MEM]][] : memref<f32>
+//  CHECK-NEXT:   vector.store %[[S]], %[[MEM]][] : memref<f32>, vector<f32>
     vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>
 
-//  CHECK-NEXT:   %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
-//  CHECK-NEXT:   memref.store %[[VV]], %[[MEM]][] : memref<f32>
+//  CHECK-NEXT:   vector.store %[[VEC]], %[[MEM]][] : memref<f32>, vector<1x1x1xf32>
     vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
 
     return
@@ -191,8 +188,8 @@ func.func @transfer_perm_map(%mem : memref<8x8xf32>, %idx : index) -> vector<4xf
 // CHECK-LABEL:   func @transfer_broadcasting(
 // CHECK-SAME:      %[[MEM:.*]]: memref<8x8xf32>,
 // CHECK-SAME:      %[[IDX:.*]]: index) -> vector<4xf32> {
-// CHECK-NEXT:      %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
-// CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4xf32>
+// CHECK-NEXT:      %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<1xf32>
+// CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<1xf32> to vector<4xf32>
 // CHECK-NEXT:      return %[[RES]] : vector<4xf32>
 // CHECK-NEXT:    }
 
@@ -208,8 +205,7 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32>, %idx : index) -> vector
 // CHECK-LABEL:   func @transfer_scalar(
 // CHECK-SAME:      %[[MEM:.*]]: memref<?x?xf32>,
 // CHECK-SAME:      %[[IDX:.*]]: index) -> vector<1xf32> {
-// CHECK-NEXT:      %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<?x?xf32>
-// CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<1xf32>
+// CHECK-NEXT:      %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<?x?xf32>, vector<1xf32>
 // CHECK-NEXT:      return %[[RES]] : vector<1xf32>
 // CHECK-NEXT:    }
 func.func @transfer_scalar(%mem : memref<?x?xf32>, %idx : index) -> vector<1xf32> {
@@ -222,8 +218,8 @@ func.func @transfer_scalar(%mem : memref<?x?xf32>, %idx : index) -> vector<1xf32
 // CHECK-LABEL:   func @transfer_broadcasting_2D(
 // CHECK-SAME:      %[[MEM:.*]]: memref<8x8xf32>,
 // CHECK-SAME:      %[[IDX:.*]]: index) -> vector<4x4xf32> {
-// CHECK-NEXT:      %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
-// CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4x4xf32>
+// CHECK-NEXT:      %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<1x1xf32>
+// CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<1x1xf32> to vector<4x4xf32>
 // CHECK-NEXT:      return %[[RES]] : vector<4x4xf32>
 // CHECK-NEXT:    }
 
@@ -322,8 +318,8 @@ func.func @transfer_read_permutations(%mem_0 : memref<?x?xf32>, %mem_1 : memref<
 // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
 
   %6 = vector.transfer_read %mem_0[%c0, %c0], %cst {in_bounds = [true], permutation_map = #map6} : memref<?x?xf32>, vector<8xf32>
-// CHECK: memref.load %{{.*}}[%[[C0]], %[[C0]]] : memref<?x?xf32>
-// CHECK: vector.broadcast %{{.*}} : f32 to vector<8xf32>
+// CHECK: vector.load %{{.*}}[%[[C0]], %[[C0]]] : memref<?x?xf32>, vector<1xf32>
+// CHECK: vector.broadcast %{{.*}} : vector<1xf32> to vector<8xf32>
 
   return %0, %1, %2, %3, %4, %5, %6 : vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
          vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,

``````````
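Stores change analogously, as the updated 0-d test above shows: a single-element `vector.store` is no longer scalarized through `vector.extract`/`vector.extractelement` + `memref.store`:

```mlir
// Lowering before this patch: the single-element store was scalarized.
%e = vector.extract %vec[0, 0, 0] : f32 from vector<1x1x1xf32>
memref.store %e, %mem[] : memref<f32>

// Lowering with this patch: the store is kept in the vector domain.
vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
```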



https://github.com/llvm/llvm-project/pull/121454

