[Mlir-commits] [mlir] 96bc223 - [mlir][linalg] Enhance FoldInsertPadIntoFill to support op chain

Lei Zhang llvmlistbot at llvm.org
Mon Feb 28 13:54:59 PST 2022


Author: Lei Zhang
Date: 2022-02-28T16:51:17-05:00
New Revision: 96bc2233c49ba1fff67c87dd0ae2fe8a46ca49e9

URL: https://github.com/llvm/llvm-project/commit/96bc2233c49ba1fff67c87dd0ae2fe8a46ca49e9
DIFF: https://github.com/llvm/llvm-project/commit/96bc2233c49ba1fff67c87dd0ae2fe8a46ca49e9.diff

LOG: [mlir][linalg] Enhance FoldInsertPadIntoFill to support op chain

If we have a chain of `tensor.insert_slice` ops that inserts a
`tensor.pad` result into a `linalg.fill`, and the inserted ranges
do not overlap, we can still elide the `tensor.pad`.
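
For example (a condensed sketch of the new `@multi_insert_pad_into_fill`
test below, with the pad region body elided as `{ ... }`), the following
chain now folds even though the padded value is not inserted directly into
the fill result:

  %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] { ... } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
  %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
  %0 = tensor.insert_slice %a   into %fill[%offset, 0, 0]  [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  %1 = tensor.insert_slice %a   into %0   [0, 128, %offset][8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  %2 = tensor.insert_slice %pad into %1   [0, 0, 256]      [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>

Since the slice written by %2 is provably disjoint from the slices written
by %0 and %1, the `tensor.pad` is elided and %input is inserted directly,
with its offsets shifted by the low padding amounts:

  %2 = tensor.insert_slice %input into %1 [1, 2, 256] [7, 123, 124] [1, 1, 1] : tensor<7x123x124xf32> into tensor<8x384x384xf32>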

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D120446

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e0ce704648365..24f492aed936a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -456,7 +456,48 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
     if (!srcPadOp)
       return failure();
 
-    auto dstFillOp = insertOp.dest().getDefiningOp<linalg::FillOp>();
+    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
+      return failure();
+
+    // Walk back the tensor.insert_slice chain to find the destination value
+    // at the start of the chain.
+    Value firstDest = insertOp.dest();
+    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
+      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
+        return failure();
+
+      // Make sure the ranges of elements accessed are disjoint. Without this,
+      // we cannot fold the tensor.pad away.
+      bool disjoint = false;
+      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
+        // If the dimension has a dynamic offset/size/stride, we cannot
+        // guarantee disjointness, so just skip it.
+        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
+            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
+            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
+          continue;
+
+        // Get the range start and end, both inclusive.
+        int64_t prevStart = prevOp.getStaticOffset(i);
+        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
+                                          prevOp.getStaticStride(i);
+        int64_t nextStart = insertOp.getStaticOffset(i);
+        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
+                                          insertOp.getStaticStride(i);
+        if (prevEnd < nextStart || nextEnd < prevStart) {
+          disjoint = true;
+          break;
+        }
+      }
+
+      if (!disjoint)
+        break;
+      firstDest = prevOp.dest();
+    }
+
+    // Check whether the first destination is a fill op. For overlapping cases,
+    // the walk above stops early, so firstDest won't be a fill op and we fail.
+    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
     if (!dstFillOp)
       return failure();
 

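As a concrete illustration of the disjointness check above (using the static
numbers from the tests below, not code from the patch): in a given dimension,
an insert at static offset 0 with size 128 and stride 1 covers the inclusive
range [0, 0 + (128 - 1) * 1] = [0, 127], while an insert at offset 256 with
the same size and stride covers [256, 383]; since 127 < 256 the two accesses
are disjoint in that dimension and the walk can continue past the earlier
insert. Any dimension with a dynamic offset, size, or stride is skipped, so
disjointness must be established from the static dimensions alone.
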
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index ca5546bd2697a..a2be638fd8078 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -642,3 +642,108 @@ func @insert_pad_into_fill(%input: tensor<?x?x?xf32>, %low0: index, %low1: index
   %0 = tensor.insert_slice %pad into %fill[0, 1, 2] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
   return %0: tensor<8x384x384xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @multi_insert_pad_into_fill
+//  CHECK-SAME: (%[[INPUT:.+]]: tensor<7x123x124xf32>, %[[A:.+]]: tensor<8x128x128xf32>, %[[OFFSET:.+]]: index)
+//       CHECK:   %[[FILL:.+]] = linalg.fill
+//       CHECK:   %[[INSERT0:.+]] = tensor.insert_slice %[[A]] into %[[FILL]][%[[OFFSET]], 0, 0] [8, 128, 128] [1, 1, 1]
+//       CHECK:   %[[INSERT1:.+]] = tensor.insert_slice %[[A]] into %[[INSERT0]][0, 128, %[[OFFSET]]] [8, 128, 128] [1, 1, 1]
+//       CHECK:                  tensor.insert_slice %[[INPUT]] into %[[INSERT1]][1, 2, 256] [7, 123, 124] [1, 1, 1]
+func @multi_insert_pad_into_fill(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %f0 : f32
+  } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
+  %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
+  %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+  %0 = tensor.insert_slice %a   into %fill[%offset, 0, 0]  [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  %1 = tensor.insert_slice %a   into %0   [0, 128, %offset][8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  %2 = tensor.insert_slice %pad into %1   [0, 0, 256]      [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  return %2: tensor<8x384x384xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @multi_insert_pad_into_fill_overlap
+func @multi_insert_pad_into_fill_overlap(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  // CHECK: tensor.pad
+  %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %f0 : f32
+  } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
+  %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
+  %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+  %0 = tensor.insert_slice %a   into %fill[%offset, 0, 0]  [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  %1 = tensor.insert_slice %a   into %0   [0, 0, 129]      [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  // The range overlaps with %1 in the third dimension: [129, 256] vs [256, 383].
+  %2 = tensor.insert_slice %pad into %1   [0, 0, 256]      [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  return %2: tensor<8x384x384xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @multi_insert_pad_into_fill_overlap
+func @multi_insert_pad_into_fill_overlap(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  // CHECK: tensor.pad
+  %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %f0 : f32
+  } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
+  %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
+  %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+  %0 = tensor.insert_slice %a   into %fill[0, 0, %offset]  [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  %1 = tensor.insert_slice %a   into %0   [0, 128, 255]    [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  // Possible overlap with %0: its third-dim offset is dynamic and the other dims overlap.
+  %2 = tensor.insert_slice %pad into %1   [0, 0, 256]      [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  return %2: tensor<8x384x384xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @multi_insert_pad_into_fill
+func @multi_insert_pad_into_fill(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT: tensor.pad
+  %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %f0 : f32
+  } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
+  %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
+  %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+  // Overlap between %0 and %1 is fine; what matters is that %2 overlaps neither.
+  // CHECK-COUNT-3: tensor.insert_slice
+  %0 = tensor.insert_slice %a   into %fill[0, 0, %offset]  [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  %1 = tensor.insert_slice %a   into %0   [0, 1, %offset]  [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  %2 = tensor.insert_slice %pad into %1   [0, 256, 256]    [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  return %2: tensor<8x384x384xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @multi_insert_pad_into_fill_mismatch
+func @multi_insert_pad_into_fill_mismatch(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %f1 = arith.constant 1.0 : f32
+  %c0 = arith.constant 0 : index
+  // CHECK: tensor.pad
+  %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %f0 : f32
+  } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
+  %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
+  // Different filling value than padding value.
+  %fill = linalg.fill(%f1, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+  %0 = tensor.insert_slice %a   into %fill[%offset, 0, 0]  [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  %1 = tensor.insert_slice %a   into %0   [0, 128, %offset][8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  %2 = tensor.insert_slice %pad into %1   [0, 0, 256]      [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+  return %2: tensor<8x384x384xf32>
+}
