[Mlir-commits] [mlir] 47f7938 - [mlir][Vector] Add support for lowering 0-d transfers to load/store.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Oct 12 05:35:24 PDT 2021


Author: Nicolas Vasilache
Date: 2021-10-12T12:35:19Z
New Revision: 47f7938a948591671db2e64d2b833fcd8d5fafda

URL: https://github.com/llvm/llvm-project/commit/47f7938a948591671db2e64d2b833fcd8d5fafda
DIFF: https://github.com/llvm/llvm-project/commit/47f7938a948591671db2e64d2b833fcd8d5fafda.diff

LOG: [mlir][Vector] Add support for lowering 0-d transfers to load/store.

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D111603

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-lowering.mlir

Removed: 
    


################################################################################
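Illustration (distilled from the test added below in vector-transfer-lowering.mlir): with these patterns, a 0-d transfer pair such as

    %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} :
      memref<f32>, vector<1xf32>
    vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} :
      vector<1xf32>, memref<f32>

lowers, after canonicalization, to a plain scalar load/store:

    %v = memref.load %M[] : memref<f32>
    memref.store %v, %M[] : memref<f32>
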
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index c76c43afbed3..46865160e6f9 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2590,6 +2590,24 @@ struct VectorLoadToMemrefLoadLowering
   }
 };
 
+/// Replace a scalar vector.store with a memref.store.
+struct VectorStoreToMemrefStoreLowering
+    : public OpRewritePattern<vector::StoreOp> {
+  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    auto vecType = storeOp.getVectorType();
+    if (vecType.getNumElements() != 1)
+      return failure();
+    Value extracted = rewriter.create<vector::ExtractOp>(
+        storeOp.getLoc(), storeOp.valueToStore(), ArrayRef<int64_t>{0});
+    rewriter.replaceOpWithNewOp<memref::StoreOp>(
+        storeOp, extracted, storeOp.base(), storeOp.indices());
+    return success();
+  }
+};
+
 /// Progressive lowering of transfer_write. This pattern supports lowering of
 /// `vector.transfer_write` to `vector.store` if all of the following hold:
 /// - Stride of most minor memref dimension must be 1.
@@ -2611,7 +2629,7 @@ struct TransferWriteToVectorStoreLowering
       return failure();
     // Permutations are handled by VectorToSCF or
     // populateVectorTransferPermutationMapLoweringPatterns.
-    if (!write.permutation_map().isMinorIdentity())
+    if (!write.isZeroD() && !write.permutation_map().isMinorIdentity())
       return failure();
     auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
     if (!memRefType)
@@ -2766,6 +2784,9 @@ struct TransferWritePermutationLowering
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.isZeroD())
+      return failure();
+
     SmallVector<unsigned> permutation;
     AffineMap map = op.permutation_map();
     if (map.isMinorIdentity())
@@ -3581,7 +3602,9 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
   patterns.add<TransferReadToVectorLoadLowering,
                TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                    maxTransferRank);
-  patterns.add<VectorLoadToMemrefLoadLowering>(patterns.getContext());
+  patterns
+      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
+          patterns.getContext());
 }
 
 void mlir::vector::populateVectorUnrollPatterns(

diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
index 825d7468dc95..c2db8a501d6b 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
@@ -1,5 +1,23 @@
 // RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -canonicalize -split-input-file | FileCheck %s
 
+// CHECK-LABEL: func @vector_transfer_ops_0d(
+//  CHECK-SAME:   %[[MEM:.*]]: memref<f32>) {
+func @vector_transfer_ops_0d(%M: memref<f32>) {
+    %f0 = constant 0.0 : f32
+
+//  CHECK-NEXT:   %[[V:.*]] = memref.load %[[MEM]][] : memref<f32>
+    %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} :
+      memref<f32>, vector<1xf32>
+
+//  CHECK-NEXT:   memref.store %[[V]], %[[MEM]][] : memref<f32>
+    vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} :
+      vector<1xf32>, memref<f32>
+  
+    return
+}
+
+// -----
+
 // transfer_read/write are lowered to vector.load/store
 // CHECK-LABEL:   func @transfer_to_load(
 // CHECK-SAME:                                %[[MEM:.*]]: memref<8x8xf32>,
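
For reference, a sketch of the intermediate step performed by the new VectorStoreToMemrefStoreLowering pattern before canonicalization folds everything into the scalar load/store checked above (the intermediate IR below is a reconstruction, not taken verbatim from a test):

    // Single-element store produced by TransferWriteToVectorStoreLowering.
    vector.store %0, %M[] : memref<f32>, vector<1xf32>

    // Rewritten by VectorStoreToMemrefStoreLowering into an extract of the
    // only element followed by a scalar memref.store.
    %s = vector.extract %0[0] : vector<1xf32>
    memref.store %s, %M[] : memref<f32>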