[Mlir-commits] [mlir] fb7ec1f - [mlir] Use VectorTransferPermutationMapLoweringPatterns in VectorToSCF
Matthias Springer
llvmlistbot at llvm.org
Tue May 18 22:46:35 PDT 2021
Author: Matthias Springer
Date: 2021-05-19T14:46:19+09:00
New Revision: fb7ec1f1873c82b758d606dc7e5b4687fc68dce2
URL: https://github.com/llvm/llvm-project/commit/fb7ec1f1873c82b758d606dc7e5b4687fc68dce2
DIFF: https://github.com/llvm/llvm-project/commit/fb7ec1f1873c82b758d606dc7e5b4687fc68dce2.diff
LOG: [mlir] Use VectorTransferPermutationMapLoweringPatterns in VectorToSCF
VectorTransferPermutationMapLoweringPatterns can now be enabled in the VectorToSCF pass via a pass option. These additional patterns lower permutation maps to minor identity maps with broadcasting, where possible, allowing for more efficient vector loads/stores. The option is disabled by default.
Differential Revision: https://reviews.llvm.org/D102593
Added:
mlir/test/Dialect/Vector/vector-transfer-lowering-to-scf.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
Removed:
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b440578754da..a1c40a741677 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -521,6 +521,9 @@ def ConvertVectorToSCF : FunctionPass<"convert-vector-to-scf"> {
"Perform full unrolling when converting vector transfers to SCF">,
Option<"targetRank", "target-rank", "unsigned", /*default=*/"1",
"Target vector rank to which transfer ops should be lowered">,
+ Option<"lowerPermutationMaps", "lower-permutation-maps", "bool",
+ /*default=*/"false", "Replace permutation maps with vector "
+ "transposes/broadcasts before lowering transfer ops">
];
}
diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
index 03765cb5532c..a999c4a1fcfc 100644
--- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
+++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
@@ -50,6 +50,7 @@ class RewritePatternSet;
struct VectorTransferToSCFOptions {
bool unroll = false;
unsigned targetRank = 1;
+ bool lowerPermutationMaps = false;
VectorTransferToSCFOptions &setUnroll(bool u) {
unroll = u;
@@ -60,6 +61,11 @@ struct VectorTransferToSCFOptions {
targetRank = r;
return *this;
}
+
+ VectorTransferToSCFOptions &setLowerPermutationMaps(bool l) {
+ lowerPermutationMaps = l;
+ return *this;
+ }
};
/// Collect a set of patterns to convert from the Vector dialect to SCF + std.
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index d0e65a1d1c94..4b5f5ce8e035 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
@@ -86,9 +87,16 @@ void populateVectorSlicesLoweringPatterns(RewritePatternSet &patterns);
/// Collect a set of transfer read/write lowering patterns.
///
/// These patterns lower transfer ops to simpler ops like `vector.load`,
-/// `vector.store` and `vector.broadcast`.
+/// `vector.store` and `vector.broadcast`. Includes all patterns of
+/// populateVectorTransferPermutationMapLoweringPatterns.
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns);
+/// Collect a set of transfer read/write lowering patterns that simplify the
+/// permutation map (e.g., converting it to a minor identity map) by inserting
+/// broadcasts and transposes.
+void populateVectorTransferPermutationMapLoweringPatterns(
+ RewritePatternSet &patterns);
+
/// These patterns materialize masks for various vector ops such as transfers.
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
bool enableIndexOptimizations);
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 9972bcf5a3ae..54783a7a3992 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -264,6 +264,7 @@ static BufferAllocs allocBuffers(OpTy xferOp) {
if (xferOp.mask()) {
auto maskType = MemRefType::get({}, xferOp.mask().getType());
auto maskBuffer = memref_alloca(maskType).value;
+ b.setInsertionPoint(xferOp);
memref_store(xferOp.mask(), maskBuffer);
result.maskBuffer = memref_load(maskBuffer);
}
@@ -476,10 +477,11 @@ struct Strategy<TransferWriteOp> {
};
template <typename OpTy>
-LogicalResult checkPrepareXferOp(OpTy xferOp, unsigned targetRank) {
+LogicalResult checkPrepareXferOp(OpTy xferOp,
+ VectorTransferToSCFOptions options) {
if (xferOp->hasAttr(kPassLabel))
return failure();
- if (xferOp.getVectorType().getRank() <= targetRank)
+ if (xferOp.getVectorType().getRank() <= options.targetRank)
return failure();
return success();
}
@@ -513,7 +515,7 @@ struct PrepareTransferReadConversion
LogicalResult matchAndRewrite(TransferReadOp xferOp,
PatternRewriter &rewriter) const override {
- if (checkPrepareXferOp(xferOp, options.targetRank).failed())
+ if (checkPrepareXferOp(xferOp, options).failed())
return failure();
ScopedContext scope(rewriter, xferOp.getLoc());
@@ -561,7 +563,7 @@ struct PrepareTransferWriteConversion
LogicalResult matchAndRewrite(TransferWriteOp xferOp,
PatternRewriter &rewriter) const override {
- if (checkPrepareXferOp(xferOp, options.targetRank).failed())
+ if (checkPrepareXferOp(xferOp, options).failed())
return failure();
ScopedContext scope(rewriter, xferOp.getLoc());
@@ -1160,12 +1162,23 @@ struct ConvertVectorToSCFPass
ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
this->fullUnroll = options.unroll;
this->targetRank = options.targetRank;
+ this->lowerPermutationMaps = options.lowerPermutationMaps;
}
void runOnFunction() override {
VectorTransferToSCFOptions options;
- options.setUnroll(fullUnroll);
- options.setTargetRank(targetRank);
+ options.unroll = fullUnroll;
+ options.targetRank = targetRank;
+ options.lowerPermutationMaps = lowerPermutationMaps;
+
+ // Lower permutation maps first.
+ if (lowerPermutationMaps) {
+ RewritePatternSet lowerTransferPatterns(getFunction().getContext());
+ mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
+ lowerTransferPatterns);
+ (void)applyPatternsAndFoldGreedily(getFunction(),
+ std::move(lowerTransferPatterns));
+ }
RewritePatternSet patterns(getFunction().getContext());
populateVectorToSCFConversionPatterns(patterns, options);
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index c7a0623f3b32..f529167a2526 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2934,8 +2934,8 @@ struct TransferWriteInsertPattern
/// - The op has no mask.
struct TransferReadToVectorLoadLowering
: public OpRewritePattern<vector::TransferReadOp> {
- TransferReadToVectorLoadLowering(MLIRContext *context)
- : OpRewritePattern<vector::TransferReadOp>(context) {}
+ using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
LogicalResult matchAndRewrite(vector::TransferReadOp read,
PatternRewriter &rewriter) const override {
SmallVector<unsigned, 4> broadcastedDims;
@@ -3009,8 +3009,8 @@ struct TransferReadToVectorLoadLowering
/// - The op has no mask.
struct TransferWriteToVectorStoreLowering
: public OpRewritePattern<vector::TransferWriteOp> {
- TransferWriteToVectorStoreLowering(MLIRContext *context)
- : OpRewritePattern<vector::TransferWriteOp>(context) {}
+ using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
PatternRewriter &rewriter) const override {
// TODO: Support non-minor-identity maps
@@ -3086,6 +3086,7 @@ struct TransferReadPermutationLowering
if (permutationMap.isIdentity())
return failure();
+ permutationMap = map.getPermutationMap(permutation, op.getContext());
// Caluclate the map of the new read by applying the inverse permutation.
permutationMap = inversePermutation(permutationMap);
AffineMap newMap = permutationMap.compose(map);
@@ -4149,13 +4150,18 @@ void mlir::vector::populateVectorTransposeLoweringPatterns(
patterns.getContext());
}
+void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<TransferReadPermutationLowering,
+ TransferWritePermutationLowering, TransferOpReduceRank>(
+ patterns.getContext());
+}
+
void mlir::vector::populateVectorTransferLoweringPatterns(
RewritePatternSet &patterns) {
- patterns
- .add<TransferReadToVectorLoadLowering, TransferWriteToVectorStoreLowering,
- TransferReadPermutationLowering, TransferWritePermutationLowering,
- TransferOpReduceRank>(
- patterns.getContext());
+ patterns.add<TransferReadToVectorLoadLowering,
+ TransferWriteToVectorStoreLowering>(patterns.getContext());
+ populateVectorTransferPermutationMapLoweringPatterns(patterns);
}
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering-to-scf.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering-to-scf.mlir
new file mode 100644
index 000000000000..5547a79957c8
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-lowering-to-scf.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -convert-vector-to-scf='lower-permutation-maps=true' -split-input-file | FileCheck %s
+
+// Ensure that the permutation map is lowered (by inserting a transpose op)
+// before lowering the vector.transfer_read.
+
+// CHECK-LABEL: func @transfer_read_2d_mask_transposed(
+// CHECK-DAG: %[[PADDING:.*]] = constant dense<-4.200000e+01> : vector<9xf32>
+// CHECK-DAG: %[[MASK:.*]] = constant dense<{{.*}}> : vector<9x4xi1>
+// CHECK: %[[MASK_MEM:.*]] = memref.alloca() : memref<vector<4x9xi1>>
+// CHECK: %[[MASK_T:.*]] = vector.transpose %[[MASK]], [1, 0] : vector<9x4xi1> to vector<4x9xi1>
+// CHECK: memref.store %[[MASK_T]], %[[MASK_MEM]][] : memref<vector<4x9xi1>>
+// CHECK: %[[MASK_CASTED:.*]] = vector.type_cast %[[MASK_MEM]] : memref<vector<4x9xi1>> to memref<4xvector<9xi1>>
+// CHECK: scf.for {{.*}} {
+// CHECK: scf.if {{.*}} {
+// CHECK: %[[MASK_LOADED:.*]] = memref.load %[[MASK_CASTED]][%{{.*}}] : memref<4xvector<9xi1>>
+// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}, %{{.*}}, %[[MASK_LOADED]] : memref<?x?xf32>, vector<9xf32>
+// CHECK: memref.store %[[READ]], %{{.*}} : memref<4xvector<9xf32>>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[RESULT:.*]] = memref.load %{{.*}} : memref<vector<4x9xf32>>
+// CHECK: %[[RESULT_T:.*]] = vector.transpose %[[RESULT]], [1, 0] : vector<4x9xf32> to vector<9x4xf32>
+// CHECK: return %[[RESULT_T]] : vector<9x4xf32>
+
+// Vector load with mask + transpose.
+func @transfer_read_2d_mask_transposed(
+ %A : memref<?x?xf32>, %base1: index, %base2: index) -> (vector<9x4xf32>) {
+ %fm42 = constant -42.0: f32
+ %mask = constant dense<[[1, 0, 1, 0], [0, 0, 1, 0],
+ [1, 1, 1, 1], [0, 1, 1, 0],
+ [1, 1, 1, 1], [1, 1, 1, 1],
+ [1, 1, 1, 1], [0, 0, 0, 0],
+ [1, 1, 1, 1]]> : vector<9x4xi1>
+ %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
+ memref<?x?xf32>, vector<9x4xf32>
+ return %f : vector<9x4xf32>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir
deleted file mode 100644
index ad43b14a71d4..000000000000
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir
+++ /dev/null
@@ -1,41 +0,0 @@
-// Run test with and without test-vector-transfer-lowering-patterns.
-
-// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
-// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
-
-// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
-// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
-
-
-memref.global "private" @gv : memref<3x4xf32> = dense<[[0. , 1. , 2. , 3. ],
- [10., 11., 12., 13.],
- [20., 21., 22., 23.]]>
-
-// Vector load with transpose.
-func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
- %fm42 = constant -42.0: f32
- %f = vector.transfer_read %A[%base1, %base2], %fm42
- {in_bounds = [true, false], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
- memref<?x?xf32>, vector<3x9xf32>
- vector.print %f: vector<3x9xf32>
- return
-}
-
-func @entry() {
- %c0 = constant 0: index
- %c1 = constant 1: index
- %c2 = constant 2: index
- %c3 = constant 3: index
- %0 = memref.get_global @gv : memref<3x4xf32>
- %A = memref.cast %0 : memref<3x4xf32> to memref<?x?xf32>
-
- // 1. Read 2D vector from 2D memref with transpose.
- call @transfer_read_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
- // CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( 20, 0, -42, -42, -42, -42, -42, -42, -42 ) )
-
- return
-}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
index 20216cc6ba6e..7e6ef94c872d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
@@ -3,7 +3,17 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
-// RUN: mlir-opt %s -convert-vector-to-scf=full-unroll=true -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -convert-vector-to-scf='lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
index 03cdc3dd8e32..e3154e60b329 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
@@ -3,7 +3,17 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
-// RUN: mlir-opt %s -convert-vector-to-scf=full-unroll=true -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -convert-vector-to-scf='lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
index 00da9278d50c..344167bae467 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
@@ -3,7 +3,17 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
-// RUN: mlir-opt %s -convert-vector-to-scf=full-unroll=true -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -convert-vector-to-scf='lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true lower-permutation-maps=true' -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
More information about the Mlir-commits
mailing list