[Mlir-commits] [mlir] 4a3defa - [mlir][vector] Refactor TransferReadToVectorLoadLowering
Matthias Springer
llvmlistbot at llvm.org
Fri Jul 16 22:00:54 PDT 2021
Author: Matthias Springer
Date: 2021-07-17T13:53:09+09:00
New Revision: 4a3defa6298ae84f88bf0ad76b50a6264ab2f337
URL: https://github.com/llvm/llvm-project/commit/4a3defa6298ae84f88bf0ad76b50a6264ab2f337
DIFF: https://github.com/llvm/llvm-project/commit/4a3defa6298ae84f88bf0ad76b50a6264ab2f337.diff
LOG: [mlir][vector] Refactor TransferReadToVectorLoadLowering
* TransferReadToVectorLoadLowering no longer generates memref.load ops.
* Add new pattern VectorLoadToMemrefLoadLowering that lowers scalar vector.loads to memref.loads.
* Add vector::BroadcastOp canonicalization pattern that folds broadcast chains.
Differential Revision: https://reviews.llvm.org/D106117
Added:
Modified:
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 1674f0e0bef15..9fbc6c3711d8b 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1346,11 +1346,25 @@ class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
}
};
+/// Collapses a chain of two broadcasts: broadcast1(broadcast2(x)) is rewritten
+/// into a single broadcast of `x` to broadcast1's result type. The
+/// intermediate broadcast is left in place; it becomes dead if it has no other
+/// users.
+struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    // Match only when the operand is itself produced by a vector.broadcast.
+    if (auto producer = broadcastOp.source().getDefiningOp<BroadcastOp>()) {
+      rewriter.replaceOpWithNewOp<BroadcastOp>(
+          broadcastOp, broadcastOp.getVectorType(), producer.source());
+      return success();
+    }
+    return failure();
+  }
+};
} // namespace
+/// Registers BroadcastToShapeCast and BroadcastFolder as the canonicalization
+/// patterns of vector.broadcast.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<BroadcastToShapeCast>(context);
+ results.add<BroadcastToShapeCast, BroadcastFolder>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index c02a79d08bdb5..4bd1ee15dece7 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2510,32 +2510,39 @@ struct TransferReadToVectorLoadLowering
return failure();
if (read.mask())
return failure();
- Operation *loadOp;
- if (!broadcastedDims.empty() &&
- unbroadcastedVectorType.getNumElements() == 1) {
- // If broadcasting is required and the number of loaded elements is 1 then
- // we can create `memref.load` instead of `vector.load`.
- loadOp = rewriter.create<memref::LoadOp>(read.getLoc(), read.source(),
- read.indices());
- } else {
- // Otherwise create `vector.load`.
- loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
- unbroadcastedVectorType,
- read.source(), read.indices());
- }
+ auto loadOp = rewriter.create<vector::LoadOp>(
+ read.getLoc(), unbroadcastedVectorType, read.source(), read.indices());
// Insert a broadcasting op if required.
if (!broadcastedDims.empty()) {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- read, read.getVectorType(), loadOp->getResult(0));
+ read, read.getVectorType(), loadOp.result());
} else {
- rewriter.replaceOp(read, loadOp->getResult(0));
+ rewriter.replaceOp(read, loadOp.result());
}
return success();
}
};
+/// Replace a single-element vector.load with a memref.load of the scalar,
+/// followed by a vector.broadcast back to the load's vector type.
+///
+/// The broadcast targets the load's own result type rather than a hard-coded
+/// vector<1xT>. Single-element vectors of higher rank (e.g. vector<1x1xf32>)
+/// are therefore replaced by a value of the same type, and all users of the
+/// original load stay well-typed.
+struct VectorLoadToMemrefLoadLowering
+    : public OpRewritePattern<vector::LoadOp> {
+  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    auto vecType = loadOp.getVectorType();
+    // Only a vector holding exactly one element can be loaded as one scalar.
+    if (vecType.getNumElements() != 1)
+      return failure();
+    auto memrefLoad = rewriter.create<memref::LoadOp>(
+        loadOp.getLoc(), loadOp.base(), loadOp.indices());
+    // Broadcast to the original vector type (any rank) so the replacement
+    // value has exactly the type of the op being replaced.
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
+                                                     memrefLoad);
+    return success();
+  }
+};
+
/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - The op writes to a memref with the default layout.
@@ -3674,8 +3681,9 @@ void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
+/// Populates `patterns` with TransferReadToVectorLoadLowering,
+/// TransferWriteToVectorStoreLowering and VectorLoadToMemrefLoadLowering, then
+/// adds the patterns from populateVectorTransferPermutationMapLoweringPatterns.
void mlir::vector::populateVectorTransferLoweringPatterns(
RewritePatternSet &patterns) {
- patterns.add<TransferReadToVectorLoadLowering,
- TransferWriteToVectorStoreLowering>(patterns.getContext());
+ patterns
+ .add<TransferReadToVectorLoadLowering, TransferWriteToVectorStoreLowering,
+ VectorLoadToMemrefLoadLowering>(patterns.getContext());
populateVectorTransferPermutationMapLoweringPatterns(patterns);
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index cc2d59a1c0257..81a407459ec45 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -613,6 +613,18 @@ func @broadcast_folding2() -> vector<4x16xi32> {
// -----
+// Two chained vector.broadcast ops (i32 -> vector<16xi32> -> vector<4x16xi32>)
+// are expected to fold into one broadcast directly from the i32 argument.
+// CHECK-LABEL: @fold_consecutive_broadcasts(
+// CHECK-SAME: %[[ARG0:.*]]: i32
+// CHECK: %[[RESULT:.*]] = vector.broadcast %[[ARG0]] : i32 to vector<4x16xi32>
+// CHECK: return %[[RESULT]]
+func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> {
+ %1 = vector.broadcast %a : i32 to vector<16xi32>
+ %2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32>
+ return %2 : vector<4x16xi32>
+}
+
+// -----
+
// CHECK-LABEL: shape_cast_constant
// CHECK-DAG: %[[CST1:.*]] = constant dense<1> : vector<3x4x2xi32>
// CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<20x2xf32>
diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
index 60bbadf59874a..931c3ba91774a 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -canonicalize -split-input-file | FileCheck %s
// transfer_read/write are lowered to vector.load/store
// CHECK-LABEL: func @transfer_to_load(
@@ -174,6 +174,21 @@ func @transfer_broadcasting(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32>
// -----
+// A one-element, in-bounds transfer_read is expected to become a scalar
+// memref.load followed by a vector.broadcast to vector<1xf32>.
+// CHECK-LABEL: func @transfer_scalar(
+// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<1xf32> {
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<?x?xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<1xf32>
+// CHECK-NEXT: return %[[RES]] : vector<1xf32>
+// CHECK-NEXT: }
+func @transfer_scalar(%mem : memref<?x?xf32>, %i : index) -> vector<1xf32> {
+ %cf0 = constant 0.0 : f32
+ %res = vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<1xf32>
+ return %res : vector<1xf32>
+}
+
+// -----
+
// An example with two broadcasted dimensions.
// CHECK-LABEL: func @transfer_broadcasting_2D(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
More information about the Mlir-commits mailing list