[Mlir-commits] [mlir] 39b9336 - [mlir][vector] Swap ExtractSliceOp(TransferWriteOp).

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 11 03:29:28 PDT 2022


Author: gysit
Date: 2022-04-11T10:28:53Z
New Revision: 39b933647444234afb3f3d14563d02e4b8ee1b38

URL: https://github.com/llvm/llvm-project/commit/39b933647444234afb3f3d14563d02e4b8ee1b38
DIFF: https://github.com/llvm/llvm-project/commit/39b933647444234afb3f3d14563d02e4b8ee1b38.diff

LOG: [mlir][vector] Swap ExtractSliceOp(TransferWriteOp).

Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is overwritten and inserted into another tensor. After this rewrite, the operations bufferize in-place since all of them work on the same %iter_arg slice.

For example:
```mlir
  %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
       : vector<8x16xf32>, tensor<8x16xf32>
  %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
       : tensor<8x16xf32> to tensor<?x?xf32>
  %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
       : tensor<?x?xf32> into tensor<27x37xf32>
```
is rewritten to
```mlir
  %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
       : tensor<27x37xf32> to tensor<?x?xf32>
  %1 = vector.transfer_write %vec, %0[%c0, %c0]
       : vector<8x16xf32>, tensor<?x?xf32>
  %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
       : tensor<?x?xf32> into tensor<27x37xf32>
```

Reviewed By: nicolasvasilache, hanchung

Differential Revision: https://reviews.llvm.org/D123190

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7d9febec632ca..758478f8d7ff8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3534,11 +3534,114 @@ struct FoldInsertSliceIntoTransferWrite
     return success();
   }
 };
+
+/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
+/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
+/// overwritten and inserted into another tensor. After this rewrite, the
+/// operations bufferize in-place since all of them work on the same slice.
+///
+/// The pattern is rooted at the tensor::InsertSliceOp and only applies when:
+///  - the InsertSliceOp and the ExtractSliceOp have unit strides,
+///  - the ExtractSliceOp and the TransferWriteOp each have a single use,
+///  - no op in the chain is rank-reducing: the slice rank equals both the
+///    transfer rank and the rank of the tensor written by the TransferWriteOp,
+///  - the ExtractSliceOp offsets and the TransferWriteOp indices are all zero,
+///  - the ExtractSliceOp and InsertSliceOp sizes match, and
+///  - the TransferWriteOp is unmasked and its vector covers the whole tensor.
+///
+/// For example:
+/// ```mlir
+///   %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
+///        : vector<8x16xf32>, tensor<8x16xf32>
+///   %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
+///        : tensor<8x16xf32> to tensor<?x?xf32>
+///   %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
+///        : tensor<?x?xf32> into tensor<27x37xf32>
+/// ```
+/// is rewritten to
+/// ```mlir
+///   %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
+///        : tensor<27x37xf32> to tensor<?x?xf32>
+///   %1 = vector.transfer_write %vec, %0[%c0, %c0]
+///        : vector<8x16xf32>, tensor<?x?xf32>
+///   %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
+///        : tensor<?x?xf32> into tensor<27x37xf32>
+/// ```
+struct SwapExtractSliceOfTransferWrite
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+public:
+  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    if (!insertOp.hasUnitStride())
+      return failure();
+    auto extractOp = insertOp.source().getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
+      return failure();
+    auto transferOp = extractOp.source().getDefiningOp<TransferWriteOp>();
+    if (!transferOp || !transferOp->hasOneUse())
+      return failure();
+
+    // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
+    // rank-reducing. The new vector::TransferWriteOp reuses the original
+    // indices and permutation map, which are sized by the rank of the tensor
+    // written by the original op, so the slice must have exactly that rank.
+    // Comparing against the transfer rank alone is not enough: a lower-rank
+    // (minor identity) transfer_write followed by an equally rank-reducing
+    // extract_slice would pass and the rewrite would produce invalid IR.
+    int64_t sliceRank = insertOp.getSourceType().getRank();
+    if (sliceRank != transferOp.getTransferRank() ||
+        sliceRank != transferOp.getShapedType().getRank()) {
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "use-def chain is rank-reducing");
+    }
+
+    // Fail if tensor::ExtractSliceOp has non-zero offset.
+    if (!extractOp.hasZeroOffset()) {
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "ExtractSliceOp has non-zero offset");
+    }
+
+    // Fail if vector::TransferWriteOp has non-zero offset.
+    if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
+          return getConstantIntValue(value) == static_cast<int64_t>(0);
+        })) {
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "TransferWriteOp has non-zero offset");
+    }
+
+    // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
+    // The sizes are compared position by position, so both lists must have
+    // the same length; a rank-extending InsertSliceOp lists more sizes than
+    // its source has dimensions and llvm::zip would silently truncate.
+    auto insertSizes = insertOp.getMixedSizes();
+    auto extractSizes = extractOp.getMixedSizes();
+    if (insertSizes.size() != extractSizes.size()) {
+      return rewriter.notifyMatchFailure(
+          insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
+    }
+    for (const auto &it : llvm::zip(insertSizes, extractSizes)) {
+      if (!isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it))) {
+        return rewriter.notifyMatchFailure(
+            insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
+      }
+    }
+
+    // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
+    assert(transferOp.getVectorType().hasStaticShape() &&
+           "expected vector to have a static shape");
+    ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
+    SmallVector<int64_t> resultShape = applyPermutationMap(
+        transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
+    if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
+      return rewriter.notifyMatchFailure(
+          insertOp, "TransferWriteOp may not write the full tensor.");
+    }
+
+    // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
+    // A vector dimension is marked in bounds only if the matching slice
+    // dimension is static and equal to the vector size; dynamic slice
+    // dimensions are conservatively marked as possibly out of bounds.
+    SmallVector<int64_t> newResultShape = applyPermutationMap(
+        transferOp.getPermutationMap(), insertOp.getSourceType().getShape());
+    SmallVector<bool> newInBounds;
+    newInBounds.reserve(newResultShape.size());
+    for (const auto &en : enumerate(newResultShape))
+      newInBounds.push_back(en.value() == vectorShape[en.index()]);
+    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+        extractOp.getLoc(), insertOp.getSourceType(), insertOp.dest(),
+        insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
+        insertOp.getMixedStrides());
+    auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
+        transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
+        transferOp.getIndices(), transferOp.getPermutationMapAttr(),
+        rewriter.getBoolArrayAttr(newInBounds));
+    rewriter.updateRootInPlace(insertOp, [&]() {
+      insertOp.sourceMutable().assign(newTransferWriteOp.getResult());
+    });
+    return success();
+  }
+};
+
 } // namespace
 
 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<FoldWaw, FoldInsertSliceIntoTransferWrite>(context);
+  // SwapExtractSliceOfTransferWrite matches a tensor::InsertSliceOp (not a
+  // vector::TransferWriteOp) whose source chain ends in a transfer_write;
+  // registering it here makes it part of the canonicalization patterns
+  // contributed by TransferWriteOp.
+  results.add<FoldWaw, FoldInsertSliceIntoTransferWrite,
+              SwapExtractSliceOfTransferWrite>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 033f17ae2fe12..336d22c5808cf 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1149,6 +1149,82 @@ func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v
 
 // -----
 
+// Positive case: the transposing transfer_write overwrites the whole 4x8
+// tensor from offset zero, the extract_slice starts at offset zero, and the
+// insert_slice uses the same sizes as the extract_slice. The extract_slice is
+// therefore taken directly from the destination %arg2 and the transfer_write
+// targets that slice.
+//       CHECK: #[[$MAP:[0-9a-z]+]] = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: func @swap_extract_slice_transfer_write
+//  CHECK-SAME:   %[[VEC:.*]]: vector<8x4xf32>
+//  CHECK-SAME:   %[[INIT_TENSOR:.*]]: tensor<4x8xf32>,
+//  CHECK-SAME:   %[[ITER_ARG:.*]]: tensor<64x64xf32>,
+//  CHECK-SAME:   %[[IV:.*]]: index, %[[SZ:.*]]: index)
+func.func @swap_extract_slice_transfer_write(%arg0 : vector<8x4xf32>,
+                                             %arg1 : tensor<4x8xf32>,
+                                             %arg2 : tensor<64x64xf32>,
+                                             %iv : index, %sz : index) -> tensor<64x64xf32> {
+  //       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+
+  //       CHECK:   %[[T0:.*]] = tensor.extract_slice %[[ITER_ARG]]
+  //  CHECK-SAME:                 [%[[IV]], 16] [%[[SZ]], 8]
+  //       CHECK:   %[[T1:.*]] = vector.transfer_write %[[VEC]]
+  //  CHECK-SAME:                 %[[T0]][%[[C0]], %[[C0]]]
+  //  CHECK-SAME:                 in_bounds = [true, false]
+  //  CHECK-SAME:                 permutation_map = #[[$MAP]]
+  //       CHECK:   %[[T2:.*]] = tensor.insert_slice %[[T1]] into %[[ITER_ARG]]
+  //  CHECK-SAME:                 [%[[IV]], 16] [%[[SZ]], 8]
+  // With the (d1, d0) map, vector dim 0 (size 8) writes the static slice dim
+  // of size 8 and stays in bounds; vector dim 1 (size 4) writes the dynamic
+  // %sz slice dim, so the rewritten op marks it as possibly out of bounds.
+  %0 = vector.transfer_write %arg0, %arg1[%c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<8x4xf32>, tensor<4x8xf32>
+  %1 = tensor.extract_slice %0[0, 0] [%sz, 8] [1, 1] : tensor<4x8xf32> to tensor<?x8xf32>
+  %2 = tensor.insert_slice %1 into %arg2[%iv, 16] [%sz, 8] [1, 1] : tensor<?x8xf32> into tensor<64x64xf32>
+
+  //       CHECK:   return %[[T2]]
+  func.return %2 : tensor<64x64xf32>
+}
+
+// -----
+
+// Negative cases: each transfer_write / extract_slice / insert_slice chain
+// below violates one precondition of the rewrite and must be left unchanged.
+// CHECK-LABEL: func @do_not_swap_extract_slice_transfer_write
+//  CHECK-SAME:   %[[VEC:.*]]: vector<8xf32>,
+//  CHECK-SAME:   %[[VEC_SMALL:.*]]: vector<4xf32>,
+//  CHECK-SAME:   %[[INIT_TENSOR:.*]]: tensor<8xf32>,
+//  CHECK-SAME:   %[[ITER_ARG:.*]]: tensor<64xf32>,
+//  CHECK-SAME:   %[[IV:.*]]: index, %[[SZ:.*]]: index)
+func.func @do_not_swap_extract_slice_transfer_write(%arg0 : vector<8xf32>,
+                                                    %arg1 : vector<4xf32>,
+                                                    %arg2 : tensor<8xf32>,
+                                                    %arg3 : tensor<64xf32>,
+                                                    %iv : index, %sz : index) -> (tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) {
+  //       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+
+  // Don't swap if the extracted and inserted slice sizes do not match
+  // (%iv is extracted, %sz is inserted).
+  //       CHECK:   %[[T0:.*]] = vector.transfer_write %[[VEC]]
+  //       CHECK:   %[[T1:.*]] = tensor.extract_slice %[[T0]]
+  //       CHECK:   %[[T2:.*]] = tensor.insert_slice %[[T1]]
+  %0 = vector.transfer_write %arg0, %arg2[%c0] {in_bounds = [true]} : vector<8xf32>, tensor<8xf32>
+  %1 = tensor.extract_slice %0[0] [%iv] [1] : tensor<8xf32> to tensor<?xf32>
+  %2 = tensor.insert_slice %1 into %arg3[%iv] [%sz] [1] : tensor<?xf32> into tensor<64xf32>
+
+  // Don't swap if the TransferWriteOp writes a vector (vector<4xf32>) that
+  // does not cover the whole tensor (tensor<8xf32>).
+  //       CHECK:   %[[T3:.*]] = vector.transfer_write %[[VEC_SMALL]]
+  //       CHECK:   %[[T4:.*]] = tensor.extract_slice %[[T3]]
+  //       CHECK:   %[[T5:.*]] = tensor.insert_slice %[[T4]]
+  %3 = vector.transfer_write %arg1, %arg2[%c0] {in_bounds = [true]} : vector<4xf32>, tensor<8xf32>
+  %4 = tensor.extract_slice %3[0] [%sz] [1] : tensor<8xf32> to tensor<?xf32>
+  %5 = tensor.insert_slice %4 into %arg3[%iv] [%sz] [1] : tensor<?xf32> into tensor<64xf32>
+
+  // Don't swap if one of the operations is rank-reducing (here the
+  // extract_slice produces a 0-d tensor).
+  //       CHECK:   %[[T6:.*]] = vector.transfer_write %[[VEC]]
+  //       CHECK:   %[[T7:.*]] = tensor.extract_slice %[[T6]]
+  //       CHECK:   %[[T8:.*]] = tensor.insert_slice %[[T7]]
+  %6 = vector.transfer_write %arg0, %arg2[%c0] {in_bounds = [true]} : vector<8xf32>, tensor<8xf32>
+  %7 = tensor.extract_slice %6[0] [1] [1] : tensor<8xf32> to tensor<f32>
+  %8 = tensor.insert_slice %7 into %arg3[%iv] [1] [1] : tensor<f32> into tensor<64xf32>
+
+  //       CHECK:   return %[[T2]], %[[T5]], %[[T8]]
+  func.return %2, %5, %8 : tensor<64xf32>, tensor<64xf32>, tensor<64xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @vector_multi_reduction_single_parallel(
 //  CHECK-SAME:     %[[v:.*]]: vector<2xf32>
 func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>) -> vector<2xf32> {


        


More information about the Mlir-commits mailing list