[Mlir-commits] [mlir] [mlir][vector] Relax strides check for 1-element vector load/stores (PR #108998)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 17 08:46:54 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Ivan Butygin (Hardcode84)
<details>
<summary>Changes</summary>
Single-element vector load/stores are equivalent to scalar load/stores, so they don't need the memref to be contiguous.
---
Full diff: https://github.com/llvm/llvm-project/pull/108998.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+9-2)
- (modified) mlir/test/Dialect/Vector/ops.mlir (+20)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d3aef4ac38af03..816447713de417 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4769,7 +4769,14 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===//
static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
+ VectorType vecTy,
MemRefType memRefTy) {
+ // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
+ // need any strides limitations.
+ if (!vecTy.isScalable() &&
+ (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
+ return success();
+
if (!isLastMemrefDimUnitStride(memRefTy))
return op->emitOpError("most minor memref dim must have unit stride");
return success();
@@ -4779,7 +4786,7 @@ LogicalResult vector::LoadOp::verify() {
VectorType resVecTy = getVectorType();
MemRefType memRefTy = getMemRefType();
- if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
+ if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
return failure();
// Checks for vector memrefs.
@@ -4811,7 +4818,7 @@ LogicalResult vector::StoreOp::verify() {
VectorType valueVecTy = getVectorType();
MemRefType memRefTy = getMemRefType();
- if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
+ if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
return failure();
// Checks for vector memrefs.
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 4759fcc9511fb2..08d1a189231bcc 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -732,6 +732,26 @@ func.func @vector_load_and_store_0d_scalar_memref(%memref : memref<200x100xf32>,
return
}
+// CHECK-LABEL: @vector_load_and_store_0d_scalar_strided_memref
+func.func @vector_load_and_store_0d_scalar_strided_memref(%memref : memref<200x100xf32, strided<[?, ?], offset: ?>>,
+ %i : index, %j : index) {
+ // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+ %0 = vector.load %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+ // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+ vector.store %0, %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+ return
+}
+
+// CHECK-LABEL: @vector_load_and_store_unit_vec_strided_memref
+func.func @vector_load_and_store_unit_vec_strided_memref(%memref : memref<200x100xf32, strided<[?, ?], offset: ?>>,
+ %i : index, %j : index) {
+ // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+ %0 = vector.load %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+ // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+ vector.store %0, %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+ return
+}
+
// CHECK-LABEL: @vector_load_and_store_1d_scalar_memref
func.func @vector_load_and_store_1d_scalar_memref(%memref : memref<200x100xf32>,
%i : index, %j : index) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/108998
More information about the Mlir-commits
mailing list