[Mlir-commits] [mlir] 39b9336 - [mlir][vector] Swap ExtractSliceOp(TransferWriteOp).
llvmlistbot at llvm.org
Mon Apr 11 03:29:28 PDT 2022
Author: gysit
Date: 2022-04-11T10:28:53Z
New Revision: 39b933647444234afb3f3d14563d02e4b8ee1b38
URL: https://github.com/llvm/llvm-project/commit/39b933647444234afb3f3d14563d02e4b8ee1b38
DIFF: https://github.com/llvm/llvm-project/commit/39b933647444234afb3f3d14563d02e4b8ee1b38.diff
LOG: [mlir][vector] Swap ExtractSliceOp(TransferWriteOp).
Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is overwritten and inserted into another tensor. After this rewrite, the operations bufferize in-place since all of them work on the same %iter_arg slice.
For example:
```mlir
%0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
: vector<8x16xf32>, tensor<8x16xf32>
%1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
: tensor<8x16xf32> to tensor<?x?xf32>
%r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
: tensor<?x?xf32> into tensor<27x37xf32>
```
folds to
```mlir
%0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
: tensor<27x37xf32> to tensor<?x?xf32>
%1 = vector.transfer_write %vec, %0[%c0, %c0]
: vector<8x16xf32>, tensor<?x?xf32>
%r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
: tensor<?x?xf32> into tensor<27x37xf32>
```

Reviewed By: nicolasvasilache, hanchung
Differential Revision: https://reviews.llvm.org/D123190
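For context on the bufferization claim: the %iter_arg in the example is typically the iteration argument of a tiled loop. Below is a minimal, hypothetical sketch of such a loop after the rewrite (the loop, names, and shapes are illustrative and not taken from this commit); because the extract, the write, and the insert all address the same slice of %iter_arg, bufferization can update the underlying buffer in place instead of allocating and copying.
```mlir
// Hypothetical loop context (illustrative only): the rewritten trio of ops
// all operate on the same [%iv, 0] slice of %iter_arg.
func.func @sketch(%t : tensor<64x64xf32>, %vec : vector<8x64xf32>) -> tensor<64x64xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %c64 = arith.constant 64 : index
  %res = scf.for %iv = %c0 to %c64 step %c8
      iter_args(%iter_arg = %t) -> (tensor<64x64xf32>) {
    // Extract the slice from the iteration argument first ...
    %0 = tensor.extract_slice %iter_arg[%iv, 0] [8, 64] [1, 1]
        : tensor<64x64xf32> to tensor<8x64xf32>
    // ... fully overwrite it ...
    %1 = vector.transfer_write %vec, %0[%c0, %c0] {in_bounds = [true, true]}
        : vector<8x64xf32>, tensor<8x64xf32>
    // ... and insert it back at the same position, so bufferization can
    // reuse the buffer of %iter_arg without a copy.
    %2 = tensor.insert_slice %1 into %iter_arg[%iv, 0] [8, 64] [1, 1]
        : tensor<8x64xf32> into tensor<64x64xf32>
    scf.yield %2 : tensor<64x64xf32>
  }
  func.return %res : tensor<64x64xf32>
}
```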
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7d9febec632ca..758478f8d7ff8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3534,11 +3534,114 @@ struct FoldInsertSliceIntoTransferWrite
return success();
}
};
+
+/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
+/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
+/// overwritten and inserted into another tensor. After this rewrite, the
+/// operations bufferize in-place since all of them work on the same slice.
+///
+/// For example:
+/// ```mlir
+/// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
+/// : vector<8x16xf32>, tensor<8x16xf32>
+/// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
+/// : tensor<8x16xf32> to tensor<?x?xf32>
+/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
+/// : tensor<?x?xf32> into tensor<27x37xf32>
+/// ```
+/// folds to
+/// ```mlir
+/// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
+/// : tensor<27x37xf32> to tensor<?x?xf32>
+/// %1 = vector.transfer_write %vec, %0[%c0, %c0]
+/// : vector<8x16xf32>, tensor<?x?xf32>
+/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
+/// : tensor<?x?xf32> into tensor<27x37xf32>
+/// ```
+struct SwapExtractSliceOfTransferWrite
+ : public OpRewritePattern<tensor::InsertSliceOp> {
+public:
+ using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
+ PatternRewriter &rewriter) const override {
+ if (!insertOp.hasUnitStride())
+ return failure();
+ auto extractOp = insertOp.source().getDefiningOp<tensor::ExtractSliceOp>();
+ if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
+ return failure();
+ auto transferOp = extractOp.source().getDefiningOp<TransferWriteOp>();
+ if (!transferOp || !transferOp->hasOneUse())
+ return failure();
+
+ // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
+ // rank-reducing.
+ if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
+ return rewriter.notifyMatchFailure(insertOp,
+ "use-def chain is rank-reducing");
+ }
+
+ // Fail if tensor::ExtractSliceOp has non-zero offset.
+ if (!extractOp.hasZeroOffset()) {
+ return rewriter.notifyMatchFailure(insertOp,
+ "ExtractSliceOp has non-zero offset");
+ }
+
+  // Fail if vector::TransferWriteOp has non-zero offset.
+ if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
+ return getConstantIntValue(value) == static_cast<int64_t>(0);
+ })) {
+ return rewriter.notifyMatchFailure(insertOp,
+ "TranferWriteOp has non-zero offset");
+ }
+
+  // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
+ for (const auto &it :
+ llvm::zip(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
+ if (!isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it))) {
+ return rewriter.notifyMatchFailure(
+ insertOp, "InsertSliceOp and ExtractSliceOp sizes
diff er");
+ }
+ }
+
+ // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
+ assert(transferOp.getVectorType().hasStaticShape() &&
+ "expected vector to have a static shape");
+ ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
+ SmallVector<int64_t> resultShape = applyPermutationMap(
+ transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
+ if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
+ return rewriter.notifyMatchFailure(
+ insertOp, "TransferWriteOp may not write the full tensor.");
+ }
+
+ // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
+ SmallVector<int64_t> newResultShape = applyPermutationMap(
+ transferOp.getPermutationMap(), insertOp.getSourceType().getShape());
+ SmallVector<bool> newInBounds;
+ for (const auto &en : enumerate(newResultShape))
+ newInBounds.push_back(en.value() == vectorShape[en.index()]);
+ auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+ extractOp.getLoc(), insertOp.getSourceType(), insertOp.dest(),
+ insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
+ insertOp.getMixedStrides());
+ auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
+ transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
+ transferOp.getIndices(), transferOp.getPermutationMapAttr(),
+ rewriter.getBoolArrayAttr(newInBounds));
+ rewriter.updateRootInPlace(insertOp, [&]() {
+ insertOp.sourceMutable().assign(newTransferWriteOp.getResult());
+ });
+ return success();
+ }
+};
+
} // namespace
void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldWaw, FoldInsertSliceIntoTransferWrite>(context);
+ results.add<FoldWaw, FoldInsertSliceIntoTransferWrite,
+ SwapExtractSliceOfTransferWrite>(context);
}
//===----------------------------------------------------------------------===//
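As an aside, the new pattern is registered through TransferWriteOp's canonicalization hook, so it runs under the ordinary canonicalizer. A minimal sketch of applying it programmatically (assumed usage with MLIR's standard greedy pattern driver; this helper is not part of the commit):
```cpp
// Minimal sketch (assumed usage, not from this commit): collect the
// TransferWriteOp canonicalization patterns, which now include
// SwapExtractSliceOfTransferWrite, and apply them until fixpoint.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

LogicalResult runTransferWriteCanonicalization(Operation *root) {
  MLIRContext *ctx = root->getContext();
  RewritePatternSet patterns(ctx);
  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
  // Greedily rewrite until no pattern applies anymore.
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```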
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 033f17ae2fe12..336d22c5808cf 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1149,6 +1149,82 @@ func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v
// -----
+// CHECK: #[[$MAP:[0-9a-z]+]] = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: func @swap_extract_slice_transfer_write
+// CHECK-SAME: %[[VEC:.*]]: vector<8x4xf32>
+// CHECK-SAME: %[[INIT_TENSOR:.*]]: tensor<4x8xf32>,
+// CHECK-SAME: %[[ITER_ARG:.*]]: tensor<64x64xf32>,
+// CHECK-SAME: %[[IV:.*]]: index, %[[SZ:.*]]: index)
+func.func @swap_extract_slice_transfer_write(%arg0 : vector<8x4xf32>,
+ %arg1 : tensor<4x8xf32>,
+ %arg2 : tensor<64x64xf32>,
+ %iv : index, %sz : index) -> tensor<64x64xf32> {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ITER_ARG]]
+ // CHECK-SAME: [%[[IV]], 16] [%[[SZ]], 8]
+ // CHECK: %[[T1:.*]] = vector.transfer_write %[[VEC]]
+ // CHECK-SAME: %[[T0]][%[[C0]], %[[C0]]]
+ // CHECK-SAME: in_bounds = [true, false]
+ // CHECK-SAME: permutation_map = #[[$MAP]]
+ // CHECK: %[[T2:.*]] = tensor.insert_slice %[[T1]] into %[[ITER_ARG]]
+ // CHECK-SAME: [%[[IV]], 16] [%[[SZ]], 8]
+ %0 = vector.transfer_write %arg0, %arg1[%c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<8x4xf32>, tensor<4x8xf32>
+ %1 = tensor.extract_slice %0[0, 0] [%sz, 8] [1, 1] : tensor<4x8xf32> to tensor<?x8xf32>
+ %2 = tensor.insert_slice %1 into %arg2[%iv, 16] [%sz, 8] [1, 1] : tensor<?x8xf32> into tensor<64x64xf32>
+
+ // CHECK: return %[[T2]]
+ func.return %2 : tensor<64x64xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_swap_extract_slice_transfer_write
+// CHECK-SAME: %[[VEC:.*]]: vector<8xf32>,
+// CHECK-SAME: %[[VEC_SMALL:.*]]: vector<4xf32>,
+// CHECK-SAME: %[[INIT_TENSOR:.*]]: tensor<8xf32>,
+// CHECK-SAME: %[[ITER_ARG:.*]]: tensor<64xf32>,
+// CHECK-SAME: %[[IV:.*]]: index, %[[SZ:.*]]: index)
+func.func @do_not_swap_extract_slice_transfer_write(%arg0 : vector<8xf32>,
+ %arg1 : vector<4xf32>,
+ %arg2 : tensor<8xf32>,
+ %arg3 : tensor<64xf32>,
+ %iv : index, %sz : index) -> (tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+
+ // Don't swap if the extracted and inserted slices do not match.
+ // CHECK: %[[T0:.*]] = vector.transfer_write %[[VEC]]
+ // CHECK: %[[T1:.*]] = tensor.extract_slice %[[T0]]
+ // CHECK: %[[T2:.*]] = tensor.insert_slice %[[T1]]
+ %0 = vector.transfer_write %arg0, %arg2[%c0] {in_bounds = [true]} : vector<8xf32>, tensor<8xf32>
+ %1 = tensor.extract_slice %0[0] [%iv] [1] : tensor<8xf32> to tensor<?xf32>
+ %2 = tensor.insert_slice %1 into %arg3[%iv] [%sz] [1] : tensor<?xf32> into tensor<64xf32>
+
+ // Don't swap if the TransferWriteOp takes a small vector.
+ // CHECK: %[[T3:.*]] = vector.transfer_write %[[VEC_SMALL]]
+ // CHECK: %[[T4:.*]] = tensor.extract_slice %[[T3]]
+ // CHECK: %[[T5:.*]] = tensor.insert_slice %[[T4]]
+ %3 = vector.transfer_write %arg1, %arg2[%c0] {in_bounds = [true]} : vector<4xf32>, tensor<8xf32>
+ %4 = tensor.extract_slice %3[0] [%sz] [1] : tensor<8xf32> to tensor<?xf32>
+ %5 = tensor.insert_slice %4 into %arg3[%iv] [%sz] [1] : tensor<?xf32> into tensor<64xf32>
+
+  // Don't swap if one of the operations is rank-reducing.
+ // CHECK: %[[T6:.*]] = vector.transfer_write %[[VEC]]
+ // CHECK: %[[T7:.*]] = tensor.extract_slice %[[T6]]
+ // CHECK: %[[T8:.*]] = tensor.insert_slice %[[T7]]
+ %6 = vector.transfer_write %arg0, %arg2[%c0] {in_bounds = [true]} : vector<8xf32>, tensor<8xf32>
+ %7 = tensor.extract_slice %6[0] [1] [1] : tensor<8xf32> to tensor<f32>
+ %8 = tensor.insert_slice %7 into %arg3[%iv] [1] [1] : tensor<f32> into tensor<64xf32>
+
+ // CHECK: return %[[T2]], %[[T5]], %[[T8]]
+ func.return %2, %5, %8 : tensor<64xf32>, tensor<64xf32>, tensor<64xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @vector_multi_reduction_single_parallel(
// CHECK-SAME: %[[v:.*]]: vector<2xf32>
func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>) -> vector<2xf32> {
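The new test cases follow the file's existing RUN setup. Assuming a built mlir-opt, something along these lines exercises them locally (the authoritative flags are in the RUN line at the top of canonicalize.mlir):
```
# Assumed invocation; check the test file's RUN line for extra flags
# such as -allow-unregistered-dialect.
mlir-opt -canonicalize -split-input-file \
  mlir/test/Dialect/Vector/canonicalize.mlir \
  | FileCheck mlir/test/Dialect/Vector/canonicalize.mlir
```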