[Mlir-commits] [mlir] 41e731f - [mlir][vector] Add additional scalar vector transfer foldings
Matthias Springer
llvmlistbot at llvm.org
Thu Dec 22 05:43:58 PST 2022
Author: Matthias Springer
Date: 2022-12-22T14:43:48+01:00
New Revision: 41e731f2a4328f37b5d4e14d912814493e206f32
URL: https://github.com/llvm/llvm-project/commit/41e731f2a4328f37b5d4e14d912814493e206f32
DIFF: https://github.com/llvm/llvm-project/commit/41e731f2a4328f37b5d4e14d912814493e206f32.diff
LOG: [mlir][vector] Add additional scalar vector transfer foldings
* Rewrite vector.transfer_write of vectors with 1 element to memref.store
* Rewrite vector.extract(vector.transfer_read) to memref.load
Differential Revision: https://reviews.llvm.org/D140391
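
For readers skimming the log, here is a hedged before/after sketch of the two new rewrites, adapted from the test cases added in this patch (value names are illustrative; op syntax follows the test file):

  // (1) vector.extract(vector.transfer_read) -> memref.load, with the extract
  //     position folded into the trailing indices via affine.apply.
  // Before:
  %0 = vector.transfer_read %m[%idx, %idx, %idx, %idx], %cst {in_bounds = [true, true]} : memref<?x?x?x?xf32>, vector<10x5xf32>
  %1 = vector.extract %0[8, 1] : vector<10x5xf32>
  // After:
  %i2 = affine.apply affine_map<()[s0] -> (s0 + 8)>()[%idx]
  %i3 = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%idx]
  %1 = memref.load %m[%idx, %idx, %i2, %i3] : memref<?x?x?x?xf32>

  // (2) vector.transfer_write of a single-element vector -> memref.store.
  // Before:
  vector.transfer_write %v, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>
  // After:
  %e = vector.extract %v[0, 0] : vector<1x1xf32>
  memref.store %e, %m[%idx, %idx, %idx] : memref<?x?x?xf32>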
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 727a356210a38..38062b9893f1a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -567,7 +567,7 @@ class FlattenContiguousRowMajorTransferWritePattern
/// %2 = vector.extractelement %0[%b : index] : vector<1024xf32>
/// Rewriting such IR (replacing one vector load with multiple scalar loads) may
/// negatively affect performance.
-class FoldScalarExtractOfTransferRead
+class RewriteScalarExtractElementOfTransferRead
: public OpRewritePattern<vector::ExtractElementOp> {
using OpRewritePattern::OpRewritePattern;
@@ -618,17 +618,79 @@ class FoldScalarExtractOfTransferRead
}
};
-/// Rewrite scalar transfer_write(broadcast) to memref.store.
-class FoldScalarTransferWriteOfBroadcast
- : public OpRewritePattern<vector::TransferWriteOp> {
+/// Rewrite extract(transfer_read) to memref.load.
+///
+/// Rewrite only if the extractelement op is the single user of the transfer op.
+/// E.g., do not rewrite IR such as:
+/// %0 = vector.transfer_read ... : vector<1024xf32>
+/// %1 = vector.extract %0[0] : vector<1024xf32>
+/// %2 = vector.extract %0[5] : vector<1024xf32>
+/// Rewriting such IR (replacing one vector load with multiple scalar loads) may
+/// negatively affect performance.
+class RewriteScalarExtractOfTransferRead
+ : public OpRewritePattern<vector::ExtractOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ // Only match scalar extracts.
+ if (extractOp.getType().isa<VectorType>())
+ return failure();
+ auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
+ if (!xferOp)
+ return failure();
+ // xfer result must have a single use. Otherwise, it may be better to
+ // perform a vector load.
+ if (!extractOp.getVector().hasOneUse())
+ return failure();
+ // Mask not supported.
+ if (xferOp.getMask())
+ return failure();
+ // Map not supported.
+ if (!xferOp.getPermutationMap().isMinorIdentity())
+ return failure();
+ // Cannot rewrite if the indices may be out of bounds. The starting point is
+ // always inbounds, so we don't care in case of 0d transfers.
+ if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
+ return failure();
+ // Construct scalar load.
+ SmallVector<Value> newIndices(xferOp.getIndices().begin(),
+ xferOp.getIndices().end());
+ for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
+ int64_t offset = it.value().cast<IntegerAttr>().getInt();
+ int64_t idx =
+ newIndices.size() - extractOp.getPosition().size() + it.index();
+ OpFoldResult ofr = makeComposedFoldedAffineApply(
+ rewriter, extractOp.getLoc(),
+ rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
+ if (ofr.is<Value>()) {
+ newIndices[idx] = ofr.get<Value>();
+ } else {
+ newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
+ extractOp.getLoc(), *getConstantIntValue(ofr));
+ }
+ }
+ if (xferOp.getSource().getType().isa<MemRefType>()) {
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
+ newIndices);
+ } else {
+ rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+ extractOp, xferOp.getSource(), newIndices);
+ }
+ return success();
+ }
+};
+
+/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
+/// to memref.store.
+class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
PatternRewriter &rewriter) const override {
// Must be a scalar write.
auto vecType = xferOp.getVectorType();
- if (vecType.getRank() != 0 &&
- (vecType.getRank() != 1 || vecType.getShape()[0] != 1))
+ if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
return failure();
// Mask not supported.
if (xferOp.getMask())
@@ -636,19 +698,25 @@ class FoldScalarTransferWriteOfBroadcast
// Map not supported.
if (!xferOp.getPermutationMap().isMinorIdentity())
return failure();
- // Must be a broadcast of a scalar.
- auto broadcastOp = xferOp.getVector().getDefiningOp<vector::BroadcastOp>();
- if (!broadcastOp || broadcastOp.getSource().getType().isa<VectorType>())
- return failure();
+ // Only float and integer element types are supported.
+ Value scalar;
+ if (vecType.getRank() == 0) {
+ // vector.extract does not support vector<f32> etc., so use
+ // vector.extractelement instead.
+ scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
+ xferOp.getVector());
+ } else {
+ SmallVector<int64_t> pos(vecType.getRank(), 0);
+ scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
+ xferOp.getVector(), pos);
+ }
// Construct a scalar store.
if (xferOp.getSource().getType().isa<MemRefType>()) {
rewriter.replaceOpWithNewOp<memref::StoreOp>(
- xferOp, broadcastOp.getSource(), xferOp.getSource(),
- xferOp.getIndices());
+ xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
} else {
rewriter.replaceOpWithNewOp<tensor::InsertOp>(
- xferOp, broadcastOp.getSource(), xferOp.getSource(),
- xferOp.getIndices());
+ xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
}
return success();
}
@@ -673,9 +741,9 @@ void mlir::vector::transferOpflowOpt(Operation *rootOp) {
void mlir::vector::populateScalarVectorTransferLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns
- .add<FoldScalarExtractOfTransferRead, FoldScalarTransferWriteOfBroadcast>(
- patterns.getContext(), benefit);
+ patterns.add<RewriteScalarExtractElementOfTransferRead,
+ RewriteScalarExtractOfTransferRead, RewriteScalarWrite>(
+ patterns.getContext(), benefit);
}
void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index d34b9c3091f69..f3fd6e0d25cb5 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -44,7 +44,9 @@ func.func @tensor_transfer_read_0d(%t: tensor<?x?x?xf32>, %idx: index) -> f32 {
// CHECK-LABEL: func @transfer_write_0d(
// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
-// CHECK: memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector<f32>
+// CHECK: %[[extract:.*]] = vector.extractelement %[[bc]][] : vector<f32>
+// CHECK: memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
func.func @transfer_write_0d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
%0 = vector.broadcast %f : f32 to vector<f32>
vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<f32>, memref<?x?x?xf32>
@@ -66,10 +68,43 @@ func.func @transfer_write_1d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
// CHECK-LABEL: func @tensor_transfer_write_0d(
// CHECK-SAME: %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
-// CHECK: %[[r:.*]] = tensor.insert %[[f]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector<f32>
+// CHECK: %[[extract:.*]] = vector.extractelement %[[bc]][] : vector<f32>
+// CHECK: %[[r:.*]] = tensor.insert %[[extract]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
// CHECK: return %[[r]]
func.func @tensor_transfer_write_0d(%t: tensor<?x?x?xf32>, %idx: index, %f: f32) -> tensor<?x?x?xf32> {
%0 = vector.broadcast %f : f32 to vector<f32>
%1 = vector.transfer_write %0, %t[%idx, %idx, %idx] : vector<f32>, tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 8)>
+// CHECK: #[[$map1:.*]] = affine_map<()[s0] -> (s0 + 1)>
+// CHECK-LABEL: func @transfer_read_2d_extract(
+// CHECK-SAME: %[[m:.*]]: memref<?x?x?x?xf32>, %[[idx:.*]]: index, %[[idx2:.*]]: index
+// CHECK: %[[added:.*]] = affine.apply #[[$map]]()[%[[idx]]]
+// CHECK: %[[added1:.*]] = affine.apply #[[$map1]]()[%[[idx]]]
+// CHECK: %[[r:.*]] = memref.load %[[m]][%[[idx]], %[[idx]], %[[added]], %[[added1]]]
+// CHECK: return %[[r]]
+func.func @transfer_read_2d_extract(%m: memref<?x?x?x?xf32>, %idx: index, %idx2: index) -> f32 {
+ %cst = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ %0 = vector.transfer_read %m[%idx, %idx, %idx, %idx], %cst {in_bounds = [true, true]} : memref<?x?x?x?xf32>, vector<10x5xf32>
+ %1 = vector.extract %0[8, 1] : vector<10x5xf32>
+ return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_arith_constant(
+// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
+// CHECK: %[[cst:.*]] = arith.constant dense<5.000000e+00> : vector<1x1xf32>
+// CHECK: %[[extract:.*]] = vector.extract %[[cst]][0, 0] : vector<1x1xf32>
+// CHECK: memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+func.func @transfer_write_arith_constant(%m: memref<?x?x?xf32>, %idx: index) {
+ %cst = arith.constant dense<5.000000e+00> : vector<1x1xf32>
+ vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>
+ return
+}
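
Also for context (not part of the patch): a minimal C++ sketch of how a transform might pull in these patterns. Only populateScalarVectorTransferLoweringPatterns comes from this patch; the include paths, helper name, and use of the greedy driver are assumptions.

  // Hypothetical helper; header locations are assumptions and may differ per revision.
  #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  static void lowerScalarVectorTransfers(mlir::Operation *rootOp) {
    mlir::RewritePatternSet patterns(rootOp->getContext());
    // Adds RewriteScalarExtractElementOfTransferRead,
    // RewriteScalarExtractOfTransferRead and RewriteScalarWrite.
    mlir::vector::populateScalarVectorTransferLoweringPatterns(patterns,
                                                               /*benefit=*/1);
    // Rewrite greedily until no pattern applies anymore.
    (void)mlir::applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
  }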