[Mlir-commits] [mlir] bdd37c9 - [mlir][tensor] Add some folders for insert/extract slice ops
Lei Zhang
llvmlistbot at llvm.org
Tue Oct 12 05:41:00 PDT 2021
Author: Lei Zhang
Date: 2021-10-12T08:40:54-04:00
New Revision: bdd37c9f494420aef954e63ab0315cc787d658b4
URL: https://github.com/llvm/llvm-project/commit/bdd37c9f494420aef954e63ab0315cc787d658b4
DIFF: https://github.com/llvm/llvm-project/commit/bdd37c9f494420aef954e63ab0315cc787d658b4.diff
LOG: [mlir][tensor] Add some folders for insert/extract slice ops
* Fold extract_slice immediately after insert_slice.
* Fold overlapping insert_slice.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D111439
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Linalg/hoisting.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index dc94c27c818c8..ee44abf870af5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1041,10 +1041,27 @@ foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
return success();
}
+/// If an ExtractSliceOp reads exactly the slice (and type) that its producing
+/// InsertSliceOp wrote, returns the InsertSliceOp's source; else a null Value.
+// TODO: This only checks the immediate producer; extend to go up the
+// insert/extract chain if the slices are disjoint.
+static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
+ auto insertOp = extractOp.source().getDefiningOp<InsertSliceOp>();
+
+ auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
+ if (insertOp && insertOp.source().getType() == extractOp.getType() &&
+ insertOp.isSameAs(extractOp, isSame)) // offsets/sizes/strides match
+ return insertOp.source();
+
+ return {};
+}
+
OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
if (getSourceType() == getType() &&
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
return this->source();
+ if (Value slice = foldExtractAfterInsertSlice(*this)) // extract(insert(x)) -> x
+ return slice;
return OpFoldResult();
}
@@ -1085,11 +1102,41 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
+/// If an InsertSliceOp's dest comes from another InsertSliceOp writing the same
+/// slice, the earlier write is overwritten: retarget this op to that op's dest.
+///
+/// Example:
+///
+/// ```mlir
+/// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
+/// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
+/// ```
+static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
+ auto prevInsertOp = insertOp.dest().getDefiningOp<InsertSliceOp>();
+
+ auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
+ if (!prevInsertOp ||
+ prevInsertOp.source().getType() != insertOp.source().getType() ||
+ !prevInsertOp.isSameAs(insertOp, isSame))
+ return failure(); // No IR was changed.
+
+ insertOp.destMutable().assign(prevInsertOp.dest()); // prevInsertOp is not erased here.
+ return success(); // insertOp was updated in place.
+}
+
OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
getSourceType() == getType() &&
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
return this->source();
+ if (succeeded(foldInsertAfterInsertSlice(*this)))
+ return getResult(); // Own result => op was folded in place.
return OpFoldResult();
}
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 959f254a82cd1..97e1741b2f158 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -409,7 +409,8 @@ func @hoist_vector_transfer_pairs_tensor_and_slices(
// CHECK-DAG: tensor.insert_slice %[[STI2]] into %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
// Does not hoist, 2 slice / insert_slice for %arg8.
%sti2 = tensor.insert_slice %w2 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
- %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ // Extract with a different stride to make sure we cannot fold this extract with the above insert.
+ %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][2, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%sti22 = tensor.insert_slice %st22 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
// CHECK: scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index d5171a6358637..1267e10f9ddbd 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -532,3 +532,30 @@ func @insert_tensor_cast_on_insert_slice_src(
: tensor<?x5x?xf32> into tensor<?x?x?xf32>
return %r : tensor<?x?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @fold_extract_insert
+// CHECK-SAME: %{{.+}}: tensor<?x?x?xf32>, %[[SLICE:.+]]: tensor<4x?x8xf32>
+func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<4x?x8xf32>) {
+ %c0 = constant 0: index
+ %c1 = constant 1: index
+ %0 = tensor.insert_slice %slice into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
+ %1 = tensor.extract_slice %0[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<?x?x?xf32> to tensor<4x?x8xf32> // Same slice and type as %0 wrote: folds to %slice.
+ // CHECK: return %[[SLICE]]
+ return %1 : tensor<4x?x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_overlapping_insert
+// CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
+func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
+ %c0 = constant 0: index
+ %c1 = constant 1: index
+ %0 = tensor.insert_slice %slice1 into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32> // Overwritten entirely by %1 below.
+ // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE2]] into %[[INPUT]]
+ %1 = tensor.insert_slice %slice2 into %0[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32> // Same slice as %0: retargeted onto %input.
+ // CHECK: return %[[INSERT]]
+ return %1 : tensor<?x?x?xf32>
+}
More information about the Mlir-commits
mailing list