[Mlir-commits] [mlir] 9d59705 - [mlir][tensor] Fold round-tripping extract/insert slice ops

Lei Zhang llvmlistbot at llvm.org
Mon Sep 19 09:59:09 PDT 2022


Author: Lei Zhang
Date: 2022-09-19T12:58:52-04:00
New Revision: 9d5970516960373dbf2cfff395009ea3f75d5919

URL: https://github.com/llvm/llvm-project/commit/9d5970516960373dbf2cfff395009ea3f75d5919
DIFF: https://github.com/llvm/llvm-project/commit/9d5970516960373dbf2cfff395009ea3f75d5919.diff

LOG: [mlir][tensor] Fold round-tripping extract/insert slice ops

Reviewed By: ThomasRaoux, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D133909

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 1960232b5f4ec..dca5487e21713 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1684,6 +1684,24 @@ static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
   return success();
 }
 
+/// Folds round-trip extract/insert slice op pairs: inserting a slice back into
+/// the tensor it was extracted from, at identical offsets, sizes and strides,
+/// reproduces that tensor. For example, %1 below can be folded into %val:
+/// ```mlir
+/// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
+/// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
+/// ```
+static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
+  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
+  // Offset/size/stride entries must be identical OpFoldResults (plain ==).
+  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
+  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
+      !extractOp.isSameAs(insertOp, isSame))
+    return nullptr; // No round-trip pair: a null Value means "no fold".
+  // The insert writes back exactly the data that was read, so it is a no-op.
+  return extractOp.getSource();
+}
+
 OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
   if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
       getSourceType() == getType() &&
@@ -1691,6 +1709,8 @@ OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
     return this->getSource();
   if (succeeded(foldInsertAfterInsertSlice(*this)))
     return getResult();
+  if (auto result = foldInsertAfterExtractSlice(*this))
+    return result;
   return OpFoldResult();
 }
 

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 83ef943abb7df..6d73d87ec2d23 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -639,7 +639,7 @@ func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4
   %c1 = arith.constant 1: index
   %0 = tensor.insert_slice %slice1 into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
   // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE2]] into %[[INPUT]]
-  %1 = tensor.insert_slice %slice2 into %0[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
+  %1 = tensor.insert_slice %slice2 into %0[0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
   // CHECK: return %[[INSERT]]
   return %1 : tensor<?x?x?xf32>
 }
@@ -1443,7 +1443,7 @@ func.func @cast_extract_slice_rank_reduce(%arg0 : tensor<128x512xf32>, %s : inde
 // -----
 
 // CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices(
-//  CHECK-SAME:     %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>, 
+//  CHECK-SAME:     %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
 //  CHECK-SAME:     %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>,
 //  CHECK-SAME:     %[[num_threads:[0-9a-z]*]]: index
 func.func @canonicalize_parallel_insert_slice_indices(
@@ -1470,7 +1470,7 @@ func.func @canonicalize_parallel_insert_slice_indices(
 // -----
 
 // CHECK-LABEL: func.func @dont_fold_parallel_insert_slice(
-//  CHECK-SAME:     %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>, 
+//  CHECK-SAME:     %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
 //  CHECK-SAME:     %[[arg1:[0-9a-z]*]]: tensor<1x5xf32>)
 func.func @dont_fold_parallel_insert_slice(
     %arg0 : tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32>
@@ -1487,3 +1487,39 @@ func.func @dont_fold_parallel_insert_slice(
   }
   return %2 : tensor<1x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @fold_insert_slice_after_extract_slice
+//  CHECK-SAME: (%[[INPUT:.+]]: tensor<1x2x2x4xf32>)
+func.func @fold_insert_slice_after_extract_slice(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> { // re-inserting an extracted slice at the same position folds away
+  %c0 = arith.constant 0 : index // dynamic zero, mixed with the static 0 offsets used by the extract
+  %0 = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+  %1 = tensor.insert_slice %0 into %input[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+  // CHECK: return %[[INPUT]]
+  return %1: tensor<1x2x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @dont_fold_mismatched_source_dst
+func.func @dont_fold_mismatched_source_dst(%input0: tensor<1x2x2x4xf32>, %input1: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> { // slice of %input0 inserted into %input1: not a round trip
+  %c0 = arith.constant 0 : index
+  // CHECK: tensor.extract_slice
+  %0 = tensor.extract_slice %input0[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+  // CHECK: tensor.insert_slice
+  %1 = tensor.insert_slice %0 into %input1[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+  return %1: tensor<1x2x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @dont_fold_mismatched_parameters
+func.func @dont_fold_mismatched_parameters(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> { // dim-1 offset differs (0 vs 1), so the pair must not fold
+  %c0 = arith.constant 0 : index
+  // CHECK: tensor.extract_slice
+  %0 = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
+  // CHECK: tensor.insert_slice
+  %1 = tensor.insert_slice %0 into %input[%c0, 1, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
+  return %1: tensor<1x2x2x4xf32>
+}


        


More information about the Mlir-commits mailing list