[Mlir-commits] [mlir] f325085 - [mlir][vector] Relax strides check for 1-element vector load/stores (#108998)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 19 03:12:36 PDT 2024


Author: Ivan Butygin
Date: 2024-09-19T13:12:32+03:00
New Revision: f3250858780b37a188875c87e57133f1192b2e60

URL: https://github.com/llvm/llvm-project/commit/f3250858780b37a188875c87e57133f1192b2e60
DIFF: https://github.com/llvm/llvm-project/commit/f3250858780b37a188875c87e57133f1192b2e60.diff

LOG: [mlir][vector] Relax strides check for 1-element vector load/stores (#108998)

Single-element vector load/stores are equivalent to scalar load/stores,
so they don't need the memref to be contiguous.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d3aef4ac38af03..816447713de417 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4769,7 +4769,14 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 
 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
+                                                 VectorType vecTy,
                                                  MemRefType memRefTy) {
+  // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
+  // need any strides limitations.
+  if (!vecTy.isScalable() &&
+      (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
+    return success();
+
   if (!isLastMemrefDimUnitStride(memRefTy))
     return op->emitOpError("most minor memref dim must have unit stride");
   return success();
@@ -4779,7 +4786,7 @@ LogicalResult vector::LoadOp::verify() {
   VectorType resVecTy = getVectorType();
   MemRefType memRefTy = getMemRefType();
 
-  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
+  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
     return failure();
 
   // Checks for vector memrefs.
@@ -4811,7 +4818,7 @@ LogicalResult vector::StoreOp::verify() {
   VectorType valueVecTy = getVectorType();
   MemRefType memRefTy = getMemRefType();
 
-  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
+  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
     return failure();
 
   // Checks for vector memrefs.

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 4759fcc9511fb2..08d1a189231bcc 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -732,6 +732,26 @@ func.func @vector_load_and_store_0d_scalar_memref(%memref : memref<200x100xf32>,
   return
 }
 
+// CHECK-LABEL: @vector_load_and_store_0d_scalar_strided_memref
+func.func @vector_load_and_store_0d_scalar_strided_memref(%memref : memref<200x100xf32, strided<[?, ?], offset: ?>>,
+                                                          %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+  return
+}
+
+// CHECK-LABEL: @vector_load_and_store_unit_vec_strided_memref
+func.func @vector_load_and_store_unit_vec_strided_memref(%memref : memref<200x100xf32, strided<[?, ?], offset: ?>>,
+                                                         %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+  return
+}
+
 // CHECK-LABEL: @vector_load_and_store_1d_scalar_memref
 func.func @vector_load_and_store_1d_scalar_memref(%memref : memref<200x100xf32>,
                                              %i : index, %j : index) {


        


More information about the Mlir-commits mailing list