[Mlir-commits] [mlir] f325085 - [mlir][vector] Relax strides check for 1-element vector load/stores (#108998)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 19 03:12:36 PDT 2024
Author: Ivan Butygin
Date: 2024-09-19T13:12:32+03:00
New Revision: f3250858780b37a188875c87e57133f1192b2e60
URL: https://github.com/llvm/llvm-project/commit/f3250858780b37a188875c87e57133f1192b2e60
DIFF: https://github.com/llvm/llvm-project/commit/f3250858780b37a188875c87e57133f1192b2e60.diff
LOG: [mlir][vector] Relax strides check for 1-element vector load/stores (#108998)
Single-element vector loads/stores are equivalent to scalar loads/stores,
so they don't need the memref to be contiguous.
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d3aef4ac38af03..816447713de417 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4769,7 +4769,14 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===//
static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
+ VectorType vecTy,
MemRefType memRefTy) {
+ // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
+ // need any strides limitations.
+ if (!vecTy.isScalable() &&
+ (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
+ return success();
+
if (!isLastMemrefDimUnitStride(memRefTy))
return op->emitOpError("most minor memref dim must have unit stride");
return success();
@@ -4779,7 +4786,7 @@ LogicalResult vector::LoadOp::verify() {
VectorType resVecTy = getVectorType();
MemRefType memRefTy = getMemRefType();
- if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
+ if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
return failure();
// Checks for vector memrefs.
@@ -4811,7 +4818,7 @@ LogicalResult vector::StoreOp::verify() {
VectorType valueVecTy = getVectorType();
MemRefType memRefTy = getMemRefType();
- if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
+ if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
return failure();
// Checks for vector memrefs.
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 4759fcc9511fb2..08d1a189231bcc 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -732,6 +732,26 @@ func.func @vector_load_and_store_0d_scalar_memref(%memref : memref<200x100xf32>,
return
}
+// CHECK-LABEL: @vector_load_and_store_0d_scalar_strided_memref
+func.func @vector_load_and_store_0d_scalar_strided_memref(%memref : memref<200x100xf32, strided<[?, ?], offset: ?>>,
+ %i : index, %j : index) {
+ // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+ %0 = vector.load %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+ // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+ vector.store %0, %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+ return
+}
+
+// CHECK-LABEL: @vector_load_and_store_unit_vec_strided_memref
+func.func @vector_load_and_store_unit_vec_strided_memref(%memref : memref<200x100xf32, strided<[?, ?], offset: ?>>,
+ %i : index, %j : index) {
+ // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+ %0 = vector.load %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+ // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+ vector.store %0, %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+ return
+}
+
// CHECK-LABEL: @vector_load_and_store_1d_scalar_memref
func.func @vector_load_and_store_1d_scalar_memref(%memref : memref<200x100xf32>,
%i : index, %j : index) {
More information about the Mlir-commits
mailing list