[Mlir-commits] [mlir] 9d59705 - [mlir][tensor] Fold round-tripping extract/insert slice ops
Lei Zhang
llvmlistbot at llvm.org
Mon Sep 19 09:59:09 PDT 2022
Author: Lei Zhang
Date: 2022-09-19T12:58:52-04:00
New Revision: 9d5970516960373dbf2cfff395009ea3f75d5919
URL: https://github.com/llvm/llvm-project/commit/9d5970516960373dbf2cfff395009ea3f75d5919
DIFF: https://github.com/llvm/llvm-project/commit/9d5970516960373dbf2cfff395009ea3f75d5919.diff
LOG: [mlir][tensor] Fold round-tripping extract/insert slice ops
Reviewed By: ThomasRaoux, nicolasvasilache
Differential Revision: https://reviews.llvm.org/D133909
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 1960232b5f4ec..dca5487e21713 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1684,6 +1684,24 @@ static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
return success();
}
+/// Folds a round-trip extract_slice/insert_slice pair. If `insertOp` inserts
+/// the result of a `tensor.extract_slice` back into the same tensor that slice
+/// was extracted from, with identical offsets, sizes, and strides, the insert
+/// reproduces its destination unchanged. Example:
+/// ```mlir
+/// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
+/// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
+/// ```
+/// Here %1 can be folded into %val.
+///
+/// Returns the extract's source (which equals the insert's destination) when
+/// the pattern matches, and a null Value otherwise.
+static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
+ // The inserted value must be produced directly by an extract_slice.
+ auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
+
+ // Offsets/sizes/strides are compared by plain OpFoldResult equality.
+ // NOTE(review): this presumably treats a static `0` and an SSA `%c0`
+ // constant as different, so such mixes only match once constants have been
+ // turned into static attributes (e.g. by canonicalization) — confirm.
+ auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
+ // Bail out unless the extract reads from exactly the tensor being written
+ // into and both ops describe the same slice.
+ if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
+ !extractOp.isSameAs(insertOp, isSame))
+ return nullptr;
+
+ return extractOp.getSource();
+}
+
OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
getSourceType() == getType() &&
@@ -1691,6 +1709,8 @@ OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
return this->getSource();
if (succeeded(foldInsertAfterInsertSlice(*this)))
return getResult();
+ if (auto result = foldInsertAfterExtractSlice(*this))
+ return result;
return OpFoldResult();
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 83ef943abb7df..6d73d87ec2d23 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -639,7 +639,7 @@ func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4
%c1 = arith.constant 1: index
%0 = tensor.insert_slice %slice1 into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE2]] into %[[INPUT]]
- %1 = tensor.insert_slice %slice2 into %0[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
+ %1 = tensor.insert_slice %slice2 into %0[0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
// CHECK: return %[[INSERT]]
return %1 : tensor<?x?x?xf32>
}
@@ -1443,7 +1443,7 @@ func.func @cast_extract_slice_rank_reduce(%arg0 : tensor<128x512xf32>, %s : inde
// -----
// CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices(
-// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
+// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
// CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[num_threads:[0-9a-z]*]]: index
func.func @canonicalize_parallel_insert_slice_indices(
@@ -1470,7 +1470,7 @@ func.func @canonicalize_parallel_insert_slice_indices(
// -----
// CHECK-LABEL: func.func @dont_fold_parallel_insert_slice(
-// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
+// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
// CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor<1x5xf32>)
func.func @dont_fold_parallel_insert_slice(
%arg0 : tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32>
@@ -1487,3 +1487,39 @@ func.func @dont_fold_parallel_insert_slice(
}
return %2 : tensor<1x5xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_insert_slice_after_extract_slice
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x2x2x4xf32>)
+func.func @fold_insert_slice_after_extract_slice(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {
+ // Positive case: a slice of %input is extracted and inserted back at the
+ // same position. The insert offsets mix the SSA constant %c0 with static
+ // 0s, and the pair is still expected to fold so %input is returned directly.
+ %c0 = arith.constant 0 : index
+ %0 = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+ %1 = tensor.insert_slice %0 into %input[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+ // CHECK: return %[[INPUT]]
+ return %1: tensor<1x2x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @dont_fold_mismatched_source_dst
+func.func @dont_fold_mismatched_source_dst(%input0: tensor<1x2x2x4xf32>, %input1: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {
+ // Negative case: the slice is extracted from %input0 but inserted into
+ // %input1, so this is not a round trip and both ops must remain.
+ %c0 = arith.constant 0 : index
+ // CHECK: tensor.extract_slice
+ %0 = tensor.extract_slice %input0[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+ // CHECK: tensor.insert_slice
+ %1 = tensor.insert_slice %0 into %input1[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+ return %1: tensor<1x2x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @dont_fold_mismatched_parameters
+func.func @dont_fold_mismatched_parameters(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {
+ // Negative case: same tensor, sizes, and strides, but the insert uses
+ // offset 1 in the second dimension where the extract used 0. The slices
+ // differ, so both ops must remain.
+ %c0 = arith.constant 0 : index
+ // CHECK: tensor.extract_slice
+ %0 = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+ // CHECK: tensor.insert_slice
+ %1 = tensor.insert_slice %0 into %input[%c0, 1, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+ return %1: tensor<1x2x2x4xf32>
+}
More information about the Mlir-commits
mailing list