[Mlir-commits] [mlir] 14726cd - [mlir][Vector] Extend xfer_read(extract)->scalar load to support multiple uses
Diego Caballero
llvmlistbot at llvm.org
Fri May 19 14:06:37 PDT 2023
Author: Diego Caballero
Date: 2023-05-19T21:03:18Z
New Revision: 14726cd691517f8d03491a3bf6ad0b338fabba1b
URL: https://github.com/llvm/llvm-project/commit/14726cd691517f8d03491a3bf6ad0b338fabba1b
DIFF: https://github.com/llvm/llvm-project/commit/14726cd691517f8d03491a3bf6ad0b338fabba1b.diff
LOG: [mlir][Vector] Extend xfer_read(extract)->scalar load to support multiple uses
This patch extends the vector.extract(vector.transfer_read) -> scalar
load patterns to support vector.transfer_read ops with multiple uses. For
now, we check that all the uses are vector.extract or vector.extractelement
operations. Support for multiple uses is gated behind a flag.
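As a sketch, with the flag enabled the patterns rewrite IR like the
following (taken from the new test case below; the exact form of the
affine map computing the second index is illustrative):

  %0 = vector.transfer_read %m[%idx], %cst {in_bounds = [true]} : memref<?xf32>, vector<16xf32>
  %1 = vector.extract %0[0] : vector<16xf32>
  %2 = vector.extract %0[1] : vector<16xf32>

into scalar loads:

  %1 = memref.load %m[%idx] : memref<?xf32>
  // Illustrative map: the second index is %idx + 1, derived from the extract position.
  %idx1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%idx)
  %2 = memref.load %m[%idx1] : memref<?xf32>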
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D150812
Added:
Modified:
mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index e95c95516c128..292398a3dc5a7 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -112,9 +112,12 @@ void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Collects patterns that lower scalar vector transfer ops to memref loads and
-/// stores when beneficial.
+/// stores when beneficial. If `allowMultipleUses` is set to true, the patterns
+/// are applied to vector transfer reads with any number of uses. Otherwise,
+/// only vector transfer reads with a single use will be lowered.
void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
+ PatternBenefit benefit,
+ bool allowMultipleUses);
/// Populate the pattern set with the following patterns:
///
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 68d8c92a94df4..af0fcd097028d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -561,27 +561,35 @@ class FlattenContiguousRowMajorTransferWritePattern
}
};
-/// Rewrite extractelement(transfer_read) to memref.load.
-///
-/// Rewrite only if the extractelement op is the single user of the transfer op.
-/// E.g., do not rewrite IR such as:
-/// %0 = vector.transfer_read ... : vector<1024xf32>
-/// %1 = vector.extractelement %0[%a : index] : vector<1024xf32>
-/// %2 = vector.extractelement %0[%b : index] : vector<1024xf32>
-/// Rewriting such IR (replacing one vector load with multiple scalar loads) may
-/// negatively affect performance.
-class RewriteScalarExtractElementOfTransferRead
- : public OpRewritePattern<vector::ExtractElementOp> {
- using OpRewritePattern::OpRewritePattern;
+/// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
+/// to `memref.load` patterns. The `match` method is shared by the
+/// `vector.extract` and `vector.extractelement` patterns.
+template <class VectorExtractOp>
+class RewriteScalarExtractOfTransferReadBase
+ : public OpRewritePattern<VectorExtractOp> {
+ using Base = OpRewritePattern<VectorExtractOp>;
- LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
- PatternRewriter &rewriter) const override {
- auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
+public:
+ RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
+ PatternBenefit benefit,
+ bool allowMultipleUses)
+ : Base::OpRewritePattern(context, benefit),
+ allowMultipleUses(allowMultipleUses) {}
+
+ LogicalResult match(VectorExtractOp extractOp) const override {
+ auto xferOp =
+ extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
if (!xferOp)
return failure();
- // xfer result must have a single use. Otherwise, it may be better to
- // perform a vector load.
- if (!extractOp.getVector().hasOneUse())
+ // If multiple uses are not allowed, check if xfer has a single use.
+ if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
+ return failure();
+ // If multiple uses are allowed, check if all the xfer uses are extract ops.
+ if (allowMultipleUses &&
+ !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
+ return isa<vector::ExtractOp, vector::ExtractElementOp>(
+ use.getOwner());
+ }))
return failure();
// Mask not supported.
if (xferOp.getMask())
@@ -589,11 +597,32 @@ class RewriteScalarExtractElementOfTransferRead
// Map not supported.
if (!xferOp.getPermutationMap().isMinorIdentity())
return failure();
- // Cannot rewrite if the indices may be out of bounds. The starting point is
- // always inbounds, so we don't care in case of 0d transfers.
- if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
+ // Cannot rewrite if the indices may be out of bounds.
+ if (xferOp.hasOutOfBoundsDim())
return failure();
+ return success();
+ }
+
+private:
+ bool allowMultipleUses;
+};
+
+/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
+///
+/// All the users of the transfer op must be either `vector.extractelement` or
+/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
+/// transfer ops with any number of users. Otherwise, rewrite only if the
+/// extract op is the single user of the transfer op. Replacing a single
+/// vector load with multiple scalar loads may negatively affect performance.
+class RewriteScalarExtractElementOfTransferRead
+ : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
+ using RewriteScalarExtractOfTransferReadBase::
+ RewriteScalarExtractOfTransferReadBase;
+
+ void rewrite(vector::ExtractElementOp extractOp,
+ PatternRewriter &rewriter) const override {
// Construct scalar load.
+ auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
xferOp.getIndices().end());
if (extractOp.getPosition()) {
@@ -617,46 +646,26 @@ class RewriteScalarExtractElementOfTransferRead
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
extractOp, xferOp.getSource(), newIndices);
}
- return success();
}
};
-/// Rewrite extract(transfer_read) to memref.load.
+/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
-/// Rewrite only if the extractelement op is the single user of the transfer op.
-/// E.g., do not rewrite IR such as:
-/// %0 = vector.transfer_read ... : vector<1024xf32>
-/// %1 = vector.extract %0[0] : vector<1024xf32>
-/// %2 = vector.extract %0[5] : vector<1024xf32>
-/// Rewriting such IR (replacing one vector load with multiple scalar loads) may
-/// negatively affect performance.
+/// All the users of the transfer op must be either `vector.extractelement` or
+/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
+/// transfer ops with any number of users. Otherwise, rewrite only if the
+/// extract op is the single user of the transfer op. Replacing a single
+/// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractOfTransferRead
- : public OpRewritePattern<vector::ExtractOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
+ using RewriteScalarExtractOfTransferReadBase::
+ RewriteScalarExtractOfTransferReadBase;
- LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
- PatternRewriter &rewriter) const override {
- // Only match scalar extracts.
- if (isa<VectorType>(extractOp.getType()))
- return failure();
- auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
- if (!xferOp)
- return failure();
- // xfer result must have a single use. Otherwise, it may be better to
- // perform a vector load.
- if (!extractOp.getVector().hasOneUse())
- return failure();
- // Mask not supported.
- if (xferOp.getMask())
- return failure();
- // Map not supported.
- if (!xferOp.getPermutationMap().isMinorIdentity())
- return failure();
- // Cannot rewrite if the indices may be out of bounds. The starting point is
- // always inbounds, so we don't care in case of 0d transfers.
- if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
- return failure();
+ void rewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
// Construct scalar load.
+ auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
xferOp.getIndices().end());
for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
@@ -680,7 +689,6 @@ class RewriteScalarExtractOfTransferRead
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
extractOp, xferOp.getSource(), newIndices);
}
- return success();
}
};
@@ -744,10 +752,12 @@ void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
}
void mlir::vector::populateScalarVectorTransferLoweringPatterns(
- RewritePatternSet &patterns, PatternBenefit benefit) {
+ RewritePatternSet &patterns, PatternBenefit benefit,
+ bool allowMultipleUses) {
patterns.add<RewriteScalarExtractElementOfTransferRead,
- RewriteScalarExtractOfTransferRead, RewriteScalarWrite>(
- patterns.getContext(), benefit);
+ RewriteScalarExtractOfTransferRead>(patterns.getContext(),
+ benefit, allowMultipleUses);
+ patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}
void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
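For downstream users of the populate API, a minimal usage sketch under the
new signature (mirroring the test pass change further below; the surrounding
pass boilerplate and the `ctx` variable are assumed):

  // Lower scalar transfers, folding transfer reads even when they have
  // multiple (extract-only) uses.
  RewritePatternSet patterns(ctx);
  vector::populateScalarVectorTransferLoweringPatterns(
      patterns, /*benefit=*/1, /*allowMultipleUses=*/true);
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));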
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index f3fd6e0d25cb5..7029dc717ca4d 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -test-scalar-vector-transfer-lowering -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-scalar-vector-transfer-lowering=allow-multiple-uses -split-input-file | FileCheck %s --check-prefix=MULTIUSE
// CHECK-LABEL: func @transfer_read_0d(
// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
@@ -108,3 +109,30 @@ func.func @transfer_write_arith_constant(%m: memref<?x?x?xf32>, %idx: index) {
vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_multi_use(
+// CHECK-SAME: %[[m:.*]]: memref<?xf32>, %[[idx:.*]]: index
+// CHECK-NOT: memref.load
+// CHECK: %[[r:.*]] = vector.transfer_read %[[m]][%[[idx]]]
+// CHECK: %[[e0:.*]] = vector.extract %[[r]][0]
+// CHECK: %[[e1:.*]] = vector.extract %[[r]][1]
+// CHECK: return %[[e0]], %[[e1]]
+
+// MULTIUSE-LABEL: func @transfer_read_multi_use(
+// MULTIUSE-SAME: %[[m:.*]]: memref<?xf32>, %[[idx0:.*]]: index
+// MULTIUSE-NOT: vector.transfer_read
+// MULTIUSE: %[[r0:.*]] = memref.load %[[m]][%[[idx0]]
+// MULTIUSE: %[[idx1:.*]] = affine.apply
+// MULTIUSE: %[[r1:.*]] = memref.load %[[m]][%[[idx1]]
+// MULTIUSE: return %[[r0]], %[[r1]]
+
+func.func @transfer_read_multi_use(%m: memref<?xf32>, %idx: index) -> (f32, f32) {
+ %cst = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %m[%idx], %cst {in_bounds = [true]} : memref<?xf32>, vector<16xf32>
+ %1 = vector.extract %0[0] : vector<16xf32>
+ %2 = vector.extract %0[1] : vector<16xf32>
+ return %1, %2 : f32, f32
+}
+
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 50dfeff635ccf..3b0cf2f83f198 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -298,23 +298,33 @@ struct TestScalarVectorTransferLoweringPatterns
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestScalarVectorTransferLoweringPatterns)
+ TestScalarVectorTransferLoweringPatterns() = default;
+ TestScalarVectorTransferLoweringPatterns(
+ const TestScalarVectorTransferLoweringPatterns &pass)
+ : PassWrapper(pass) {}
+
StringRef getArgument() const final {
return "test-scalar-vector-transfer-lowering";
}
StringRef getDescription() const final {
return "Test lowering of scalar vector transfers to memref loads/stores.";
}
- TestScalarVectorTransferLoweringPatterns() = default;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<affine::AffineDialect, memref::MemRefDialect,
tensor::TensorDialect, vector::VectorDialect>();
}
+ Option<bool> allowMultipleUses{
+ *this, "allow-multiple-uses",
+ llvm::cl::desc("Fold transfer operations with multiple uses"),
+ llvm::cl::init(false)};
+
void runOnOperation() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
- vector::populateScalarVectorTransferLoweringPatterns(patterns);
+ vector::populateScalarVectorTransferLoweringPatterns(
+ patterns, /*benefit=*/1, allowMultipleUses.getValue());
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};