[Mlir-commits] [mlir] 7ddeffe - [mlir] Lower permutation maps on TransferWriteOps
Matthias Springer
llvmlistbot at llvm.org
Sun May 16 23:39:09 PDT 2021
Author: Matthias Springer
Date: 2021-05-17T15:30:46+09:00
New Revision: 7ddeffee55766005327abbac85838225069cc164
URL: https://github.com/llvm/llvm-project/commit/7ddeffee55766005327abbac85838225069cc164
DIFF: https://github.com/llvm/llvm-project/commit/7ddeffee55766005327abbac85838225069cc164.diff
LOG: [mlir] Lower permutation maps on TransferWriteOps
Add TransferWritePermutationLowering, which replaces permutation maps of TransferWriteOps with vector.transpose.
Differential Revision: https://reviews.llvm.org/D102548
Added:
Modified:
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 1effce2f5679c..c7a0623f3b32a 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3137,6 +3137,92 @@ struct TransferReadPermutationLowering
   }
 };
 
+/// Lower transfer_write op with permutation into a transfer_write with a
+/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
+/// Ex:
+///     vector.transfer_write %v ...
+///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
+/// into:
+///     %tmp = vector.transpose %v, [1, 2, 0]
+///     vector.transfer_write %tmp ...
+///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
+///
+///     vector.transfer_write %v ...
+///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
+/// into:
+///     %tmp = vector.transpose %v, [1, 0]
+///     vector.transfer_write %tmp ...
+///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
+///
+/// Vector dim `i` of the original op is written along the memory dim named by
+/// result `i` of the permutation map, so the new (minor identity) vector is
+/// obtained by transposing with the *inverse* of the compressed map. The mask
+/// and `in_bounds` are indexed by vector dim and are permuted the same way.
+struct TransferWritePermutationLowering
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<unsigned> permutation;
+    AffineMap map = op.permutation_map();
+    if (map.isMinorIdentity())
+      return failure();
+    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
+      return failure();
+
+    // Remove unused dims from the permutation map. E.g.:
+    //        (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
+    // comp = (d0, d1, d2) -> (d2, d0, d1)
+    AffineMap comp = compressUnusedDims(map);
+    // transfer_write maps may not broadcast; a constant result would make comp
+    // non-invertible, so bail out rather than crash below.
+    if (!comp.isPermutation())
+      return failure();
+    // Result dim `j` of vector.transpose is input dim `indices[j]`; it must be
+    // the vector dim written along memory dim `j`, i.e. comp's inverse.
+    // E.g.: comp = (d0, d1, d2) -> (d2, d0, d1) ==> indices = [1, 2, 0]
+    AffineMap inverse = inversePermutation(comp);
+    SmallVector<int64_t> indices;
+    llvm::transform(inverse.getResults(), std::back_inserter(indices),
+                    [](AffineExpr expr) {
+                      return expr.cast<AffineDimExpr>().getPosition();
+                    });
+
+    // Transpose mask operand (it has the same shape as the vector).
+    Value newMask = op.mask() ? rewriter.create<vector::TransposeOp>(
+                                    op.getLoc(), op.mask(), indices)
+                              : Value();
+
+    // Permute in_bounds like the vector: new dim `j` takes the flag of old
+    // dim `indices[j]`.
+    ArrayAttr newInBounds;
+    if (op.in_bounds()) {
+      ArrayRef<Attribute> oldInBounds = op.in_bounds().getValue().getValue();
+      SmallVector<bool> newInBoundsValues;
+      newInBoundsValues.reserve(indices.size());
+      for (int64_t pos : indices)
+        newInBoundsValues.push_back(
+            oldInBounds[pos].cast<BoolAttr>().getValue());
+      newInBounds = rewriter.getBoolArrayAttr(newInBoundsValues);
+    }
+
+    // Generate new transfer_write operation. Forward the tensor result type
+    // (if any) so the replacement has as many results as `op`.
+    Value newVec = rewriter.create<vector::TransposeOp>(
+        op.getLoc(), op.vector(), indices);
+    auto newMap = AffineMap::getMinorIdentityMap(
+        map.getNumDims(), map.getNumResults(), rewriter.getContext());
+    Type resultType =
+        op->getNumResults() == 1 ? op->getResult(0).getType() : Type();
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        op, resultType, newVec, op.source(), op.indices(), newMap, newMask,
+        newInBounds);
+
+    return success();
+  }
+};
+
 /// Lower transfer_read op with broadcast in the leading dimensions into
 /// transfer_read of lower rank + vector.broadcast.
 /// Ex: vector.transfer_read ...
@@ -4089,7 +4175,8 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
     RewritePatternSet &patterns) {
   patterns
       .add<TransferReadToVectorLoadLowering, TransferWriteToVectorStoreLowering,
-           TransferReadPermutationLowering, TransferOpReduceRank>(
+           TransferReadPermutationLowering, TransferWritePermutationLowering,
+           TransferOpReduceRank>(
           patterns.getContext());
 }
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
index 28a267967942c..60bbadf59874a 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
@@ -267,3 +267,27 @@ func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?
     vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
     vector<7x14x8x16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_permutations
+func @transfer_write_permutations(%arg0 : memref<?x?x?x?xf32>,
+    %v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>) -> () {
+  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  %c0 = constant 0 : index
+  %m = constant 1 : i1
+
+  %mask0 = splat %m : vector<7x14x8x16xi1>
+  vector.transfer_write %v1, %arg0[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, memref<?x?x?x?xf32>
+  // Vector dim i is written along memory dim map[i], so the transpose uses
+  // the inverse of (d2, d1, d3, d0), i.e. [3, 1, 0, 2].
+  // CHECK: %[[NEW_MASK0:.*]] = vector.transpose %{{.*}} [3, 1, 0, 2] : vector<7x14x8x16xi1> to vector<16x14x7x8xi1>
+  // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [3, 1, 0, 2] : vector<7x14x8x16xf32> to vector<16x14x7x8xf32>
+  // CHECK: vector.transfer_write %[[NEW_VEC0]], %arg0[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[NEW_MASK0]] {in_bounds = [true, false, true, false]} : vector<16x14x7x8xf32>, memref<?x?x?x?xf32>
+
+  vector.transfer_write %v2, %arg0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
+  // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %{{.*}} [1, 0] : vector<8x16xf32> to vector<16x8xf32>
+  // CHECK: vector.transfer_write %[[NEW_VEC1]], %arg0[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] : vector<16x8xf32>, memref<?x?x?x?xf32>
+
+  return
+}
More information about the Mlir-commits
mailing list