[Mlir-commits] [mlir] 2ec98ff - [mlir][vector] Add scalar vector xfer to memref patterns
Matthias Springer
llvmlistbot at llvm.org
Mon Dec 19 01:32:04 PST 2022
Author: Matthias Springer
Date: 2022-12-19T10:27:49+01:00
New Revision: 2ec98ffbf12163ee4ff9f4e674eba714bce24ec1
URL: https://github.com/llvm/llvm-project/commit/2ec98ffbf12163ee4ff9f4e674eba714bce24ec1
DIFF: https://github.com/llvm/llvm-project/commit/2ec98ffbf12163ee4ff9f4e674eba714bce24ec1.diff
LOG: [mlir][vector] Add scalar vector xfer to memref patterns
These patterns devectorize scalar transfers such as vector<f32> or vector<1xf32>, rewriting them as scalar memref/tensor loads and stores.
Differential Revision: https://reviews.llvm.org/D140215
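For illustration, a minimal before/after sketch of the memref cases, adapted from the new test file below; the value names (%m, %i, %j, %f, %cst, %pos) are placeholders, and tensor sources are rewritten to tensor.extract/tensor.insert instead of memref.load/store:

  // Before: a transfer_read whose only user is a single extractelement.
  %cst = arith.constant 0.0 : f32
  %0 = vector.transfer_read %m[%i, %i, %i], %cst {in_bounds = [true]} : memref<?x?x?xf32>, vector<5xf32>
  %1 = vector.extractelement %0[%j : index] : vector<5xf32>

  // After: a scalar load at the folded index %i + %j.
  %pos = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %j]
  %1 = memref.load %m[%i, %i, %pos] : memref<?x?x?xf32>

  // Before: a scalar transfer_write of a broadcast value.
  %0 = vector.broadcast %f : f32 to vector<1xf32>
  vector.transfer_write %0, %m[%i, %i, %i] : vector<1xf32>, memref<?x?x?xf32>

  // After: a plain scalar store.
  memref.store %f, %m[%i, %i, %i] : memref<?x?x?xf32>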
Added:
mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index a4735aefd61c2..0028abee51c27 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -142,6 +142,11 @@ void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Collects patterns that lower scalar vector transfer ops to memref loads and
+/// stores when beneficial.
+void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index cfbf289b94bdf..6fb1b8c18a122 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
MLIRMemRefDialect
MLIRSCFDialect
MLIRSideEffectInterfaces
+ MLIRTensorDialect
MLIRTransforms
MLIRVectorDialect
MLIRVectorInterfaces
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index b59b10c43678b..727a356210a38 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -11,8 +11,10 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
@@ -556,6 +558,101 @@ class FlattenContiguousRowMajorTransferWritePattern
}
};
+/// Rewrite extractelement(transfer_read) to memref.load.
+///
+/// Rewrite only if the extractelement op is the single user of the transfer op.
+/// E.g., do not rewrite IR such as:
+/// %0 = vector.transfer_read ... : vector<1024xf32>
+/// %1 = vector.extractelement %0[%a : index] : vector<1024xf32>
+/// %2 = vector.extractelement %0[%b : index] : vector<1024xf32>
+/// Rewriting such IR (replacing one vector load with multiple scalar loads) may
+/// negatively affect performance.
+class FoldScalarExtractOfTransferRead
+ : public OpRewritePattern<vector::ExtractElementOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
+ PatternRewriter &rewriter) const override {
+ auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
+ if (!xferOp)
+ return failure();
+ // xfer result must have a single use. Otherwise, it may be better to
+ // perform a vector load.
+ if (!extractOp.getVector().hasOneUse())
+ return failure();
+ // Mask not supported.
+ if (xferOp.getMask())
+ return failure();
+ // Map not supported.
+ if (!xferOp.getPermutationMap().isMinorIdentity())
+ return failure();
+ // Cannot rewrite if the indices may be out of bounds. The starting point is
+ // always inbounds, so we don't care in case of 0d transfers.
+ if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
+ return failure();
+ // Construct scalar load.
+ SmallVector<Value> newIndices(xferOp.getIndices().begin(),
+ xferOp.getIndices().end());
+ if (extractOp.getPosition()) {
+ AffineExpr sym0, sym1;
+ bindSymbols(extractOp.getContext(), sym0, sym1);
+ OpFoldResult ofr = makeComposedFoldedAffineApply(
+ rewriter, extractOp.getLoc(), sym0 + sym1,
+ {newIndices[newIndices.size() - 1], extractOp.getPosition()});
+ if (ofr.is<Value>()) {
+ newIndices[newIndices.size() - 1] = ofr.get<Value>();
+ } else {
+ newIndices[newIndices.size() - 1] =
+ rewriter.create<arith::ConstantIndexOp>(extractOp.getLoc(),
+ *getConstantIntValue(ofr));
+ }
+ }
+ if (xferOp.getSource().getType().isa<MemRefType>()) {
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
+ newIndices);
+ } else {
+ rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+ extractOp, xferOp.getSource(), newIndices);
+ }
+ return success();
+ }
+};
+
+/// Rewrite scalar transfer_write(broadcast) to memref.store.
+class FoldScalarTransferWriteOfBroadcast
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
+ PatternRewriter &rewriter) const override {
+ // Must be a scalar write.
+ auto vecType = xferOp.getVectorType();
+ if (vecType.getRank() != 0 &&
+ (vecType.getRank() != 1 || vecType.getShape()[0] != 1))
+ return failure();
+ // Mask not supported.
+ if (xferOp.getMask())
+ return failure();
+ // Map not supported.
+ if (!xferOp.getPermutationMap().isMinorIdentity())
+ return failure();
+ // Must be a broadcast of a scalar.
+ auto broadcastOp = xferOp.getVector().getDefiningOp<vector::BroadcastOp>();
+ if (!broadcastOp || broadcastOp.getSource().getType().isa<VectorType>())
+ return failure();
+ // Construct a scalar store.
+ if (xferOp.getSource().getType().isa<MemRefType>()) {
+ rewriter.replaceOpWithNewOp<memref::StoreOp>(
+ xferOp, broadcastOp.getSource(), xferOp.getSource(),
+ xferOp.getIndices());
+ } else {
+ rewriter.replaceOpWithNewOp<tensor::InsertOp>(
+ xferOp, broadcastOp.getSource(), xferOp.getSource(),
+ xferOp.getIndices());
+ }
+ return success();
+ }
+};
} // namespace
void mlir::vector::transferOpflowOpt(Operation *rootOp) {
@@ -574,6 +671,13 @@ void mlir::vector::transferOpflowOpt(Operation *rootOp) {
opt.removeDeadOp();
}
+void mlir::vector::populateScalarVectorTransferLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns
+ .add<FoldScalarExtractOfTransferRead, FoldScalarTransferWriteOfBroadcast>(
+ patterns.getContext(), benefit);
+}
+
void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
new file mode 100644
index 0000000000000..d34b9c3091f69
--- /dev/null
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s -test-scalar-vector-transfer-lowering -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_0d(
+// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
+// CHECK: %[[r:.*]] = memref.load %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: return %[[r]]
+func.func @transfer_read_0d(%m: memref<?x?x?xf32>, %idx: index) -> f32 {
+ %cst = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %m[%idx, %idx, %idx], %cst : memref<?x?x?xf32>, vector<f32>
+ %1 = vector.extractelement %0[] : vector<f32>
+ return %1 : f32
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @transfer_read_1d(
+// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[idx2:.*]]: index
+// CHECK: %[[added:.*]] = affine.apply #[[$map]]()[%[[idx]], %[[idx2]]]
+// CHECK: %[[r:.*]] = memref.load %[[m]][%[[idx]], %[[idx]], %[[added]]]
+// CHECK: return %[[r]]
+func.func @transfer_read_1d(%m: memref<?x?x?xf32>, %idx: index, %idx2: index) -> f32 {
+ %cst = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ %0 = vector.transfer_read %m[%idx, %idx, %idx], %cst {in_bounds = [true]} : memref<?x?x?xf32>, vector<5xf32>
+ %1 = vector.extractelement %0[%idx2 : index] : vector<5xf32>
+ return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_transfer_read_0d(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index
+// CHECK: %[[r:.*]] = tensor.extract %[[t]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: return %[[r]]
+func.func @tensor_transfer_read_0d(%t: tensor<?x?x?xf32>, %idx: index) -> f32 {
+ %cst = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %t[%idx, %idx, %idx], %cst : tensor<?x?x?xf32>, vector<f32>
+ %1 = vector.extractelement %0[] : vector<f32>
+ return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_0d(
+// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
+// CHECK: memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+func.func @transfer_write_0d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
+ %0 = vector.broadcast %f : f32 to vector<f32>
+ vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<f32>, memref<?x?x?xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_1d(
+// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
+// CHECK: memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+func.func @transfer_write_1d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
+ %0 = vector.broadcast %f : f32 to vector<1xf32>
+ vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<1xf32>, memref<?x?x?xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_transfer_write_0d(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
+// CHECK: %[[r:.*]] = tensor.insert %[[f]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: return %[[r]]
+func.func @tensor_transfer_write_0d(%t: tensor<?x?x?xf32>, %idx: index, %f: f32) -> tensor<?x?x?xf32> {
+ %0 = vector.broadcast %f : f32 to vector<f32>
+ %1 = vector.transfer_write %0, %t[%idx, %idx, %idx] : vector<f32>, tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index b033186580940..48dc95a1a431a 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -462,6 +462,33 @@ struct TestVectorTransferFullPartialSplitPatterns
}
};
+struct TestScalarVectorTransferLoweringPatterns
+ : public PassWrapper<TestScalarVectorTransferLoweringPatterns,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestScalarVectorTransferLoweringPatterns)
+
+ StringRef getArgument() const final {
+ return "test-scalar-vector-transfer-lowering";
+ }
+ StringRef getDescription() const final {
+ return "Test lowering of scalar vector transfers to memref loads/stores.";
+ }
+ TestScalarVectorTransferLoweringPatterns() = default;
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<AffineDialect, memref::MemRefDialect, tensor::TensorDialect,
+ vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ MLIRContext *ctx = &getContext();
+ RewritePatternSet patterns(ctx);
+ vector::populateScalarVectorTransferLoweringPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestVectorTransferOpt
: public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
@@ -869,6 +896,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
+ PassRegistration<TestScalarVectorTransferLoweringPatterns>();
+
PassRegistration<TestVectorTransferOpt>();
PassRegistration<TestVectorTransferLoweringPatterns>();