[Mlir-commits] [mlir] bdd37c9 - [mlir][tensor] Add some folders for insert/extract slice ops

Lei Zhang llvmlistbot at llvm.org
Tue Oct 12 05:41:00 PDT 2021


Author: Lei Zhang
Date: 2021-10-12T08:40:54-04:00
New Revision: bdd37c9f494420aef954e63ab0315cc787d658b4

URL: https://github.com/llvm/llvm-project/commit/bdd37c9f494420aef954e63ab0315cc787d658b4
DIFF: https://github.com/llvm/llvm-project/commit/bdd37c9f494420aef954e63ab0315cc787d658b4.diff

LOG: [mlir][tensor] Add some folders for insert/extract slice ops

* Fold extract_slice immediately after insert_slice.
* Fold overlapping insert_slice.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D111439

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Linalg/hoisting.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index dc94c27c818c8..ee44abf870af5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1041,10 +1041,27 @@ foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
   return success();
 }
 
+/// If an ExtractSliceOp reads exactly the slice (offsets, sizes, strides)
+/// written by its producing InsertSliceOp, return the InsertSliceOp's source.
+// TODO: This only checks the immediate producer; extend to go up the
+// insert/extract chain if the slices are disjoint.
+static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
+  auto insertOp = extractOp.source().getDefiningOp<InsertSliceOp>();
+  // Exact OpFoldResult equality: an SSA constant never matches an Attribute.
+  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
+  if (insertOp && insertOp.source().getType() == extractOp.getType() &&
+      insertOp.isSameAs(extractOp, isSame))
+    return insertOp.source();
+  // Null Value: no InsertSliceOp producer, or a different type or slice.
+  return {};
+}
+
 OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
   if (getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->source();
+  if (Value slice = foldExtractAfterInsertSlice(*this))
+    return slice; // The producing insert_slice's source value.
   return OpFoldResult();
 }
 
@@ -1085,11 +1102,41 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
 }
 
+/// If an InsertSliceOp overwrites exactly the slice (offsets, sizes, strides,
+/// and source type) written by the InsertSliceOp producing its destination,
+/// rewire its destination in place to that producer's destination.
+/// Example:
+///
+/// ```mlir
+///   %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
+///   %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
+/// ```
+static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
+  auto prevInsertOp = insertOp.dest().getDefiningOp<InsertSliceOp>();
+  // Exact OpFoldResult equality: an SSA constant never matches an Attribute.
+  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
+  if (!prevInsertOp ||
+      prevInsertOp.source().getType() != insertOp.source().getType() ||
+      !prevInsertOp.isSameAs(insertOp, isSame))
+    return failure();
+  // The earlier write is fully overwritten; success() means insertOp mutated.
+  insertOp.destMutable().assign(prevInsertOp.dest());
+  return success();
+}
+
 OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
   if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
       getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->source();
+  if (succeeded(foldInsertAfterInsertSlice(*this)))
+    return getResult(); // Folded in place: the dest operand was rewired.
   return OpFoldResult();
 }
 

diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 959f254a82cd1..97e1741b2f158 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -409,7 +409,8 @@ func @hoist_vector_transfer_pairs_tensor_and_slices(
       // CHECK-DAG:     tensor.insert_slice %[[STI2]] into %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
       // Does not hoist, 2 slice / insert_slice for %arg8.
       %sti2 = tensor.insert_slice %w2 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
-      %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      // Extract with a different stride to make sure we cannot fold this extract with the above insert.
+      %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][2, 1] : tensor<?x?xf32> to tensor<?x?xf32>
       %sti22 = tensor.insert_slice %st22 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
 
       // CHECK:     scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>

diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index d5171a6358637..1267e10f9ddbd 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -532,3 +532,30 @@ func @insert_tensor_cast_on_insert_slice_src(
     : tensor<?x5x?xf32> into tensor<?x?x?xf32>
   return %r : tensor<?x?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_extract_insert
+//  CHECK-SAME: %{{.+}}: tensor<?x?x?xf32>, %[[SLICE:.+]]: tensor<4x?x8xf32>
+func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<4x?x8xf32>) {
+  %c0 = constant 0: index // Both slice ops below use identical offset/size/stride operands.
+  %c1 = constant 1: index
+  %0 = tensor.insert_slice %slice into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
+  %1 = tensor.extract_slice %0[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<?x?x?xf32> to tensor<4x?x8xf32>
+  // CHECK: return %[[SLICE]]
+  return %1 : tensor<4x?x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_overlapping_insert
+//  CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
+func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
+  %c0 = constant 0: index // %1 overwrites exactly the slice %0 wrote, so %0 is bypassed.
+  %c1 = constant 1: index
+  %0 = tensor.insert_slice %slice1 into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
+  // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE2]] into %[[INPUT]]
+  %1 = tensor.insert_slice %slice2 into %0[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
+  // CHECK: return %[[INSERT]]
+  return %1 : tensor<?x?x?xf32>
+}


        


More information about the Mlir-commits mailing list