[Mlir-commits] [mlir] 27451a0 - [mlir][vector] Fold transfer ops and tensor.extract/insert_slice.

Matthias Springer llvmlistbot at llvm.org
Wed Sep 29 17:34:10 PDT 2021


Author: Matthias Springer
Date: 2021-09-30T09:28:00+09:00
New Revision: 27451a05ed4d13294182ec7e999a9d4f90bc0d12

URL: https://github.com/llvm/llvm-project/commit/27451a05ed4d13294182ec7e999a9d4f90bc0d12
DIFF: https://github.com/llvm/llvm-project/commit/27451a05ed4d13294182ec7e999a9d4f90bc0d12.diff

LOG: [mlir][vector] Fold transfer ops and tensor.extract/insert_slice.

* Fold tensor.extract_slice into vector.transfer_read (illustrated below).
* Fold tensor.insert_slice into vector.transfer_write (illustrated below).

Differential Revision: https://reviews.llvm.org/D110627
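
For illustration, the new patterns rewrite IR roughly as follows (these
snippets mirror the doc comments added in this patch; the SSA names are
illustrative only):

  // Fold tensor.extract_slice into vector.transfer_read:
  %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
      : tensor<?x?xf32> to tensor<?x?xf32>
  %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
      : tensor<?x?xf32>, vector<4x5xf32>
  // becomes:
  %p0 = addi %a, %e : index
  %p1 = addi %b, %f : index
  %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
      : tensor<?x?xf32>, vector<4x5xf32>

  // Fold tensor.insert_slice into vector.transfer_write:
  %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
      : vector<4x5xf32>, tensor<4x5xf32>
  %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
      : tensor<4x5xf32> into tensor<?x?xf32>
  // becomes:
  %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
      : vector<4x5xf32>, tensor<?x?xf32>

Both folds require unit-stride slices and in-bounds, unmasked transfers with
identity permutation maps; the write fold additionally requires the transfer
to completely overwrite its destination tensor.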

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/StandardOps/Utils/Utils.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
index 1d2adc62d2714..835841ed382f6 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
@@ -78,6 +78,12 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
 Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                       OpFoldResult ofr);
 
+/// Similar to the other overload, but converts multiple OpFoldResults into
+/// Values.
+SmallVector<Value>
+getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
+                                ArrayRef<OpFoldResult> valueOrAttrVec);
+
 /// Helper struct to build simple arithmetic quantities with minimal type
 /// inference support.
 struct ArithBuilder {

diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 911a9c60c1451..f24da79b01673 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1277,6 +1277,7 @@ def Vector_TransferReadOp :
       "ArrayAttr":$inBounds)>
   ];
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 

diff --git a/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp b/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp
index 3f66738cc78f6..d52b8dc3d349b 100644
--- a/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/StandardOps/Utils/Utils.cpp
@@ -58,6 +58,15 @@ Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
   return b.create<ConstantIndexOp>(loc, attr.getValue().getSExtValue());
 }
 
+SmallVector<Value>
+mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
+                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
+  return llvm::to_vector<4>(
+      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
+        return getValueOrCreateConstantIndexOp(b, loc, value);
+      }));
+}
+
 Value ArithBuilder::_and(Value lhs, Value rhs) {
   return b.create<AndOp>(loc, lhs, rhs);
 }

diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 0b8f0efc16b02..4ce8011bb7fec 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
@@ -2649,6 +2650,74 @@ void TransferReadOp::getEffects(
                          SideEffects::DefaultResource::get());
 }
 
+namespace {
+/// Fold transfer_reads of a tensor.extract_slice op. E.g.:
+///
+/// ```
+/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
+///     : tensor<?x?xf32> to tensor<?x?xf32>
+/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
+///     : tensor<?x?xf32>, vector<4x5xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %p0 = addi %a, %e : index
+/// %p1 = addi %b, %f : index
+/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
+///     : tensor<?x?xf32>, vector<4x5xf32>
+/// ```
+struct FoldExtractSliceIntoTransferRead
+    : public OpRewritePattern<TransferReadOp> {
+public:
+  using OpRewritePattern<TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransferReadOp xferOp,
+                                PatternRewriter &rewriter) const override {
+    if (xferOp.hasOutOfBoundsDim())
+      return failure();
+    if (!xferOp.permutation_map().isIdentity())
+      return failure();
+    if (xferOp.mask())
+      return failure();
+    auto extractOp = xferOp.source().getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractOp)
+      return failure();
+    if (!extractOp.hasUnitStride())
+      return failure();
+
+    int64_t rankReduced =
+        extractOp.getSourceType().getRank() - extractOp.getType().getRank();
+    SmallVector<Value> newIndices;
+    // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
+    // indices first.
+    for (int64_t i = 0; i < rankReduced; ++i) {
+      OpFoldResult offset = extractOp.getMixedOffsets()[i];
+      newIndices.push_back(getValueOrCreateConstantIndexOp(
+          rewriter, extractOp.getLoc(), offset));
+    }
+    for (auto it : llvm::enumerate(xferOp.indices())) {
+      OpFoldResult offset =
+          extractOp.getMixedOffsets()[it.index() + rankReduced];
+      newIndices.push_back(
+          rewriter.create<AddIOp>(xferOp->getLoc(), it.value(),
+                                  getValueOrCreateConstantIndexOp(
+                                      rewriter, extractOp.getLoc(), offset)));
+    }
+    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
+    rewriter.replaceOpWithNewOp<TransferReadOp>(xferOp, xferOp.getVectorType(),
+                                                extractOp.source(), newIndices,
+                                                xferOp.padding(), inBounds);
+
+    return success();
+  }
+};
+} // namespace
+
+void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                 MLIRContext *context) {
+  results.add<FoldExtractSliceIntoTransferRead>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TransferWriteOp
 //===----------------------------------------------------------------------===//
@@ -2958,11 +3027,61 @@ class foldWAW final : public OpRewritePattern<TransferWriteOp> {
     return failure();
   }
 };
+
+/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
+/// could directly write to the insert_slice's destination. E.g.:
+///
+/// ```
+/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
+///     : vector<4x5xf32>, tensor<4x5xf32>
+/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
+///     : tensor<4x5xf32> into tensor<?x?xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
+///     : vector<4x5xf32>, tensor<?x?xf32>
+/// ```
+struct FoldInsertSliceIntoTransferWrite
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+public:
+  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    if (!insertOp.hasUnitStride())
+      return failure();
+    auto xferOp = insertOp.source().getDefiningOp<TransferWriteOp>();
+    if (!xferOp)
+      return failure();
+    if (xferOp.hasOutOfBoundsDim())
+      return failure();
+    if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
+      return failure();
+    if (xferOp.mask())
+      return failure();
+    // Fold only if the TransferWriteOp completely overwrites the `source` with
+    // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
+    // content is the data of the vector.
+    if (!llvm::equal(xferOp.getVectorType().getShape(),
+                     xferOp.getShapedType().getShape()))
+      return failure();
+    if (!xferOp.permutation_map().isIdentity())
+      return failure();
+
+    SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
+        rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
+    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
+    rewriter.replaceOpWithNewOp<TransferWriteOp>(
+        insertOp, xferOp.vector(), insertOp.dest(), indices, inBounds);
+    return success();
+  }
+};
 } // namespace
 
 void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<foldWAW>(context);
+  results.add<foldWAW, FoldInsertSliceIntoTransferWrite>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 2abcef93582b5..8b3674e59b7f4 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -960,3 +960,69 @@ func @insert_extract_to_shapecast(%arg0 : vector<1x1x4xf32>,
   %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
   return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+//   CHECK-DAG:   %[[c4:.*]] = constant 4 : index
+//   CHECK-DAG:   %[[c8:.*]] = constant 8 : index
+//       CHECK:   %[[add:.*]] = addi %[[s1]], %[[c4]]
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<5x6xf32>
+//       CHECK:   return %[[r]]
+func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %cst = constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>
+  return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+//   CHECK-DAG:   %[[c3:.*]] = constant 3 : index
+//   CHECK-DAG:   %[[c5:.*]] = constant 5 : index
+//   CHECK-DAG:   %[[c10:.*]] = constant 10 : index
+//       CHECK:   %[[add:.*]] = addi %[[s1]], %[[c3]]
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?x?xf32>, vector<5x6xf32>
+//       CHECK:   return %[[r]]
+func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %cst = constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1, 6] [1, %s2, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
+  return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+//       CHECK:   %[[c3:.*]] = constant 3 : index
+//       CHECK:   %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x12xf32>
+//       CHECK:   return %[[r]]
+func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x12xf32> {
+  %c0 = constant 0 : index
+  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+  %1 = tensor.insert_slice %0 into %t1[3, %s] [5, 6] [1, 1] : tensor<5x6xf32> into tensor<?x12xf32>
+  return %1 : tensor<?x12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+//   CHECK-DAG:   %[[c3:.*]] = constant 3 : index
+//   CHECK-DAG:   %[[c4:.*]] = constant 4 : index
+//       CHECK:   %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x?x12xf32>
+//       CHECK:   return %[[r]]
+func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
+  %c0 = constant 0 : index
+  %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+  %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
+  return %1 : tensor<?x?x12xf32>
+}
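
Note: the new folds are registered through the ops' getCanonicalizationPatterns
hooks, so they fire as part of regular canonicalization. The test cases above
are typically exercised by running the canonicalizer over canonicalize.mlir,
e.g. with mlir-opt -canonicalize.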
