[Mlir-commits] [mlir] [mlir][vector] Relax strides check for 1-element vector load/stores (PR #108998)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 17 08:46:54 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Ivan Butygin (Hardcode84)

<details>
<summary>Changes</summary>

Single-element vector load/stores are equivalent to scalar load/stores, so they don't need the memref to be contiguous.

---
Full diff: https://github.com/llvm/llvm-project/pull/108998.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+9-2) 
- (modified) mlir/test/Dialect/Vector/ops.mlir (+20) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d3aef4ac38af03..816447713de417 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4769,7 +4769,14 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 
 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
+                                                 VectorType vecTy,
                                                  MemRefType memRefTy) {
+  // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
+  // need any strides limitations.
+  if (!vecTy.isScalable() &&
+      (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
+    return success();
+
   if (!isLastMemrefDimUnitStride(memRefTy))
     return op->emitOpError("most minor memref dim must have unit stride");
   return success();
@@ -4779,7 +4786,7 @@ LogicalResult vector::LoadOp::verify() {
   VectorType resVecTy = getVectorType();
   MemRefType memRefTy = getMemRefType();
 
-  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
+  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
     return failure();
 
   // Checks for vector memrefs.
@@ -4811,7 +4818,7 @@ LogicalResult vector::StoreOp::verify() {
   VectorType valueVecTy = getVectorType();
   MemRefType memRefTy = getMemRefType();
 
-  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
+  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
     return failure();
 
   // Checks for vector memrefs.
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 4759fcc9511fb2..08d1a189231bcc 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -732,6 +732,26 @@ func.func @vector_load_and_store_0d_scalar_memref(%memref : memref<200x100xf32>,
   return
 }
 
+// CHECK-LABEL: @vector_load_and_store_0d_scalar_strided_memref
+func.func @vector_load_and_store_0d_scalar_strided_memref(%memref : memref<200x100xf32, strided<[?, ?], offset: ?>>,
+                                                          %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+  return
+}
+
+// CHECK-LABEL: @vector_load_and_store_unit_vec_strided_memref
+func.func @vector_load_and_store_unit_vec_strided_memref(%memref : memref<200x100xf32, strided<[?, ?], offset: ?>>,
+                                                         %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+  return
+}
+
 // CHECK-LABEL: @vector_load_and_store_1d_scalar_memref
 func.func @vector_load_and_store_1d_scalar_memref(%memref : memref<200x100xf32>,
                                              %i : index, %j : index) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/108998


More information about the Mlir-commits mailing list