[Mlir-commits] [mlir] 47f7938 - [mlir][Vector] Add support for lowering 0-d transfers to load/store.
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Oct 12 05:35:24 PDT 2021
Author: Nicolas Vasilache
Date: 2021-10-12T12:35:19Z
New Revision: 47f7938a948591671db2e64d2b833fcd8d5fafda
URL: https://github.com/llvm/llvm-project/commit/47f7938a948591671db2e64d2b833fcd8d5fafda
DIFF: https://github.com/llvm/llvm-project/commit/47f7938a948591671db2e64d2b833fcd8d5fafda.diff
LOG: [mlir][Vector] Add support for lowering 0-d transfers to load/store.
Reviewed By: pifon2a
Differential Revision: https://reviews.llvm.org/D111603
Added:
Modified:
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index c76c43afbed3..46865160e6f9 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2590,6 +2590,24 @@ struct VectorLoadToMemrefLoadLowering
}
};
+/// Replace a vector.store of a single-element vector with a memref.store of
+/// that element.
+///
+/// Matches any vector type that holds exactly one element (e.g.
+/// `vector<1xf32>` or `vector<1x1xf32>`). The scalar is pulled out with a
+/// vector.extract at position 0 along every dimension and is then stored at
+/// the original base memref and indices.
+struct VectorStoreToMemrefStoreLowering
+    : public OpRewritePattern<vector::StoreOp> {
+  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    auto vecType = storeOp.getVectorType();
+    if (vecType.getNumElements() != 1)
+      return failure();
+    // The only in-bounds position of a single-element vector is 0 along every
+    // dimension. Supplying one index per dimension yields the scalar element
+    // itself; a shorter position list would produce a sub-vector, which
+    // memref.store cannot accept.
+    SmallVector<int64_t> position(vecType.getRank(), 0);
+    Value extracted = rewriter.create<vector::ExtractOp>(
+        storeOp.getLoc(), storeOp.valueToStore(), position);
+    rewriter.replaceOpWithNewOp<memref::StoreOp>(
+        storeOp, extracted, storeOp.base(), storeOp.indices());
+    return success();
+  }
+};
+
/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
@@ -2611,7 +2629,7 @@ struct TransferWriteToVectorStoreLowering
return failure();
// Permutations are handled by VectorToSCF or
// populateVectorTransferPermutationMapLoweringPatterns.
- if (!write.permutation_map().isMinorIdentity())
+ if (!write.isZeroD() && !write.permutation_map().isMinorIdentity())
return failure();
auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
if (!memRefType)
@@ -2766,6 +2784,9 @@ struct TransferWritePermutationLowering
LogicalResult matchAndRewrite(vector::TransferWriteOp op,
PatternRewriter &rewriter) const override {
+ if (op.isZeroD())
+ return failure();
+
SmallVector<unsigned> permutation;
AffineMap map = op.permutation_map();
if (map.isMinorIdentity())
@@ -3581,7 +3602,9 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
patterns.add<TransferReadToVectorLoadLowering,
TransferWriteToVectorStoreLowering>(patterns.getContext(),
maxTransferRank);
- patterns.add<VectorLoadToMemrefLoadLowering>(patterns.getContext());
+ patterns
+ .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
+ patterns.getContext());
}
void mlir::vector::populateVectorUnrollPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
index 825d7468dc95..c2db8a501d6b 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
@@ -1,5 +1,23 @@
// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -canonicalize -split-input-file | FileCheck %s
+// Rank-0 memref accessed through single-element vector transfers that use the
+// `()->(0)` permutation map. After the transfer lowering patterns and
+// -canonicalize (see the RUN line), only a scalar memref.load followed by a
+// scalar memref.store of the loaded value is expected in the function body.
+// CHECK-LABEL: func @vector_transfer_ops_0d(
+// CHECK-SAME: %[[MEM:.*]]: memref<f32>) {
+func @vector_transfer_ops_0d(%M: memref<f32>) {
+ %f0 = constant 0.0 : f32
+
+// The load is matched on the line right after the signature, so the padding
+// constant above is expected not to survive lowering + canonicalization.
+// CHECK-NEXT: %[[V:.*]] = memref.load %[[MEM]][] : memref<f32>
+ %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} :
+ memref<f32>, vector<1xf32>
+
+// The stored value is the scalar captured from the load above.
+// CHECK-NEXT: memref.store %[[V]], %[[MEM]][] : memref<f32>
+ vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} :
+ vector<1xf32>, memref<f32>
+
+ return
+}
+
+// -----
+
// transfer_read/write are lowered to vector.load/store
// CHECK-LABEL: func @transfer_to_load(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
More information about the Mlir-commits mailing list