[Mlir-commits] [mlir] 14726cd - [mlir][Vector] Extend xfer_read(extract)->scalar load to support multiple uses

Diego Caballero llvmlistbot at llvm.org
Fri May 19 14:06:37 PDT 2023


Author: Diego Caballero
Date: 2023-05-19T21:03:18Z
New Revision: 14726cd691517f8d03491a3bf6ad0b338fabba1b

URL: https://github.com/llvm/llvm-project/commit/14726cd691517f8d03491a3bf6ad0b338fabba1b
DIFF: https://github.com/llvm/llvm-project/commit/14726cd691517f8d03491a3bf6ad0b338fabba1b.diff

LOG: [mlir][Vector] Extend xfer_read(extract)->scalar load to support multiple uses

This patch extends the vector.extract(vector.transfer_read) -> scalar
load patterns to support vector.transfer_read ops with multiple uses. For
now, we require that all uses are vector.extract or vector.extractelement
operations. Support for multiple uses is gated behind a flag.
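For illustration (mirroring the new test case added below), with the flag
enabled, IR such as

  %0 = vector.transfer_read %m[%idx], %cst {in_bounds = [true]} : memref<?xf32>, vector<16xf32>
  %1 = vector.extract %0[0] : vector<16xf32>
  %2 = vector.extract %0[1] : vector<16xf32>

is rewritten to scalar loads along the lines of

  %1 = memref.load %m[%idx] : memref<?xf32>
  %idx1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%idx)
  %2 = memref.load %m[%idx1] : memref<?xf32>

where %idx1 is an illustrative name; the exact affine.apply form is
whatever the existing single-use lowering already produces.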

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D150812

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index e95c95516c128..292398a3dc5a7 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -112,9 +112,12 @@ void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
                                           PatternBenefit benefit = 1);
 
 /// Collects patterns that lower scalar vector transfer ops to memref loads and
-/// stores when beneficial.
+/// stores when beneficial. If `allowMultipleUses` is set to true, the patterns
+/// are applied to vector transfer reads with any number of uses. Otherwise,
+/// only vector transfer reads with a single use will be lowered.
 void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
-                                                  PatternBenefit benefit = 1);
+                                                  PatternBenefit benefit,
+                                                  bool allowMultipleUses);
 
 /// Populate the pattern set with the following patterns:
 ///

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 68d8c92a94df4..af0fcd097028d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -561,27 +561,35 @@ class FlattenContiguousRowMajorTransferWritePattern
   }
 };
 
-/// Rewrite extractelement(transfer_read) to memref.load.
-///
-/// Rewrite only if the extractelement op is the single user of the transfer op.
-/// E.g., do not rewrite IR such as:
-/// %0 = vector.transfer_read ... : vector<1024xf32>
-/// %1 = vector.extractelement %0[%a : index] : vector<1024xf32>
-/// %2 = vector.extractelement %0[%b : index] : vector<1024xf32>
-/// Rewriting such IR (replacing one vector load with multiple scalar loads) may
-/// negatively affect performance.
-class RewriteScalarExtractElementOfTransferRead
-    : public OpRewritePattern<vector::ExtractElementOp> {
-  using OpRewritePattern::OpRewritePattern;
+/// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
+/// to `memref.load` patterns. The `match` method is shared for both
+/// `vector.extract` and `vector.extractelement`.
+template <class VectorExtractOp>
+class RewriteScalarExtractOfTransferReadBase
+    : public OpRewritePattern<VectorExtractOp> {
+  using Base = OpRewritePattern<VectorExtractOp>;
 
-  LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
+public:
+  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
+                                         PatternBenefit benefit,
+                                         bool allowMultipleUses)
+      : Base::OpRewritePattern(context, benefit),
+        allowMultipleUses(allowMultipleUses) {}
+
+  LogicalResult match(VectorExtractOp extractOp) const override {
+    auto xferOp =
+        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
     if (!xferOp)
       return failure();
-    // xfer result must have a single use. Otherwise, it may be better to
-    // perform a vector load.
-    if (!extractOp.getVector().hasOneUse())
+    // If multiple uses are not allowed, check if xfer has a single use.
+    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
+      return failure();
+    // If multiple uses are allowed, check if all the xfer uses are extract ops.
+    if (allowMultipleUses &&
+        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
+          return isa<vector::ExtractOp, vector::ExtractElementOp>(
+              use.getOwner());
+        }))
       return failure();
     // Mask not supported.
     if (xferOp.getMask())
@@ -589,11 +597,32 @@ class RewriteScalarExtractElementOfTransferRead
     // Map not supported.
     if (!xferOp.getPermutationMap().isMinorIdentity())
       return failure();
-    // Cannot rewrite if the indices may be out of bounds. The starting point is
-    // always inbounds, so we don't care in case of 0d transfers.
-    if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
+    // Cannot rewrite if the indices may be out of bounds.
+    if (xferOp.hasOutOfBoundsDim())
       return failure();
+    return success();
+  }
+
+private:
+  bool allowMultipleUses;
+};
+
+/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
+///
+/// All the users of the transfer op must be either `vector.extractelement` or
+/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
+/// transfer ops with any number of users. Otherwise, rewrite only if the
+/// extract op is the single user of the transfer op. Replacing a single
+/// vector load with multiple scalar loads may negatively affect performance.
+class RewriteScalarExtractElementOfTransferRead
+    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
+  using RewriteScalarExtractOfTransferReadBase::
+      RewriteScalarExtractOfTransferReadBase;
+
+  void rewrite(vector::ExtractElementOp extractOp,
+               PatternRewriter &rewriter) const override {
     // Construct scalar load.
+    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                   xferOp.getIndices().end());
     if (extractOp.getPosition()) {
@@ -617,46 +646,26 @@ class RewriteScalarExtractElementOfTransferRead
       rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
           extractOp, xferOp.getSource(), newIndices);
     }
-    return success();
   }
 };
 
-/// Rewrite extract(transfer_read) to memref.load.
+/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
 ///
-/// Rewrite only if the extractelement op is the single user of the transfer op.
-/// E.g., do not rewrite IR such as:
-/// %0 = vector.transfer_read ... : vector<1024xf32>
-/// %1 = vector.extract %0[0] : vector<1024xf32>
-/// %2 = vector.extract %0[5] : vector<1024xf32>
-/// Rewriting such IR (replacing one vector load with multiple scalar loads) may
-/// negatively affect performance.
+/// All the users of the transfer op must be either `vector.extractelement` or
+/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
+/// transfer ops with any number of users. Otherwise, rewrite only if the
+/// extract op is the single user of the transfer op. Replacing a single
+/// vector load with multiple scalar loads may negatively affect performance.
 class RewriteScalarExtractOfTransferRead
-    : public OpRewritePattern<vector::ExtractOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
+  using RewriteScalarExtractOfTransferReadBase::
+      RewriteScalarExtractOfTransferReadBase;
 
-  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    // Only match scalar extracts.
-    if (isa<VectorType>(extractOp.getType()))
-      return failure();
-    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
-    if (!xferOp)
-      return failure();
-    // xfer result must have a single use. Otherwise, it may be better to
-    // perform a vector load.
-    if (!extractOp.getVector().hasOneUse())
-      return failure();
-    // Mask not supported.
-    if (xferOp.getMask())
-      return failure();
-    // Map not supported.
-    if (!xferOp.getPermutationMap().isMinorIdentity())
-      return failure();
-    // Cannot rewrite if the indices may be out of bounds. The starting point is
-    // always inbounds, so we don't care in case of 0d transfers.
-    if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
-      return failure();
+  void rewrite(vector::ExtractOp extractOp,
+               PatternRewriter &rewriter) const override {
     // Construct scalar load.
+    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                   xferOp.getIndices().end());
     for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
@@ -680,7 +689,6 @@ class RewriteScalarExtractOfTransferRead
       rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
           extractOp, xferOp.getSource(), newIndices);
     }
-    return success();
   }
 };
 
@@ -744,10 +752,12 @@ void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
 }
 
 void mlir::vector::populateScalarVectorTransferLoweringPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+    RewritePatternSet &patterns, PatternBenefit benefit,
+    bool allowMultipleUses) {
   patterns.add<RewriteScalarExtractElementOfTransferRead,
-               RewriteScalarExtractOfTransferRead, RewriteScalarWrite>(
-      patterns.getContext(), benefit);
+               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
+                                                   benefit, allowMultipleUses);
+  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
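Note that even with `allowMultipleUses` set, the shared `match` bails out
as soon as any user of the transfer read is not an extract-like op (and,
as before, when the op has a mask, a non-minor-identity permutation map,
or a possibly out-of-bounds dimension). A hypothetical IR snippet that is
left untouched:

  %0 = vector.transfer_read %m[%idx], %cst {in_bounds = [true]} : memref<?xf32>, vector<16xf32>
  %1 = vector.extract %0[0] : vector<16xf32>
  // Second user is not vector.extract/extractelement -> no rewrite.
  vector.print %0 : vector<16xf32>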

diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index f3fd6e0d25cb5..7029dc717ca4d 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -test-scalar-vector-transfer-lowering -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-scalar-vector-transfer-lowering=allow-multiple-uses -split-input-file | FileCheck %s --check-prefix=MULTIUSE
 
 // CHECK-LABEL: func @transfer_read_0d(
 //  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
@@ -108,3 +109,30 @@ func.func @transfer_write_arith_constant(%m: memref<?x?x?xf32>, %idx: index) {
   vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_multi_use(
+//  CHECK-SAME:   %[[m:.*]]: memref<?xf32>, %[[idx:.*]]: index
+//   CHECK-NOT:   memref.load
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[m]][%[[idx]]]
+//       CHECK:   %[[e0:.*]] = vector.extract %[[r]][0]
+//       CHECK:   %[[e1:.*]] = vector.extract %[[r]][1]
+//       CHECK:   return %[[e0]], %[[e1]]
+
+// MULTIUSE-LABEL: func @transfer_read_multi_use(
+//  MULTIUSE-SAME:   %[[m:.*]]: memref<?xf32>, %[[idx0:.*]]: index
+//   MULTIUSE-NOT:   vector.transfer_read
+//       MULTIUSE:   %[[r0:.*]] = memref.load %[[m]][%[[idx0]]
+//       MULTIUSE:   %[[idx1:.*]] = affine.apply
+//       MULTIUSE:   %[[r1:.*]] = memref.load %[[m]][%[[idx1]]
+//       MULTIUSE:   return %[[r0]], %[[r1]]
+
+func.func @transfer_read_multi_use(%m: memref<?xf32>, %idx: index) -> (f32, f32) {
+  %cst = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %m[%idx], %cst {in_bounds = [true]} : memref<?xf32>, vector<16xf32>
+  %1 = vector.extract %0[0] : vector<16xf32>
+  %2 = vector.extract %0[1] : vector<16xf32>
+  return %1, %2 : f32, f32
+}
+

diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 50dfeff635ccf..3b0cf2f83f198 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -298,23 +298,33 @@ struct TestScalarVectorTransferLoweringPatterns
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
       TestScalarVectorTransferLoweringPatterns)
 
+  TestScalarVectorTransferLoweringPatterns() = default;
+  TestScalarVectorTransferLoweringPatterns(
+      const TestScalarVectorTransferLoweringPatterns &pass)
+      : PassWrapper(pass) {}
+
   StringRef getArgument() const final {
     return "test-scalar-vector-transfer-lowering";
   }
   StringRef getDescription() const final {
     return "Test lowering of scalar vector transfers to memref loads/stores.";
   }
-  TestScalarVectorTransferLoweringPatterns() = default;
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<affine::AffineDialect, memref::MemRefDialect,
                     tensor::TensorDialect, vector::VectorDialect>();
   }
 
+  Option<bool> allowMultipleUses{
+      *this, "allow-multiple-uses",
+      llvm::cl::desc("Fold transfer operations with multiple uses"),
+      llvm::cl::init(false)};
+
   void runOnOperation() override {
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    vector::populateScalarVectorTransferLoweringPatterns(patterns);
+    vector::populateScalarVectorTransferLoweringPatterns(
+        patterns, /*benefit=*/1, allowMultipleUses.getValue());
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };

