[Mlir-commits] [mlir] 96bc223 - [mlir][linalg] Enhance FoldInsertPadIntoFill to support op chain
Lei Zhang
llvmlistbot at llvm.org
Mon Feb 28 13:54:59 PST 2022
Author: Lei Zhang
Date: 2022-02-28T16:51:17-05:00
New Revision: 96bc2233c49ba1fff67c87dd0ae2fe8a46ca49e9
URL: https://github.com/llvm/llvm-project/commit/96bc2233c49ba1fff67c87dd0ae2fe8a46ca49e9
DIFF: https://github.com/llvm/llvm-project/commit/96bc2233c49ba1fff67c87dd0ae2fe8a46ca49e9.diff
LOG: [mlir][linalg] Enhance FoldInsertPadIntoFill to support op chain
If we have a chain of `tensor.insert_slice` ops inserting some
`tensor.pad` op into a `linalg.fill` and ranges do not overlap,
we can also elide the `tensor.pad` later.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D120446
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e0ce704648365..24f492aed936a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -456,7 +456,48 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
if (!srcPadOp)
return failure();
- auto dstFillOp = insertOp.dest().getDefiningOp<linalg::FillOp>();
+ if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
+ return failure();
+
+ // Walk back the tensor.insert_slice chain and find the first destination
+ // value at the start of the chain.
+ Value firstDest = insertOp.dest();
+ while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
+ if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
+ return failure();
+
+ // Make sure the range of values accessed are disjoint. Without this, we
+ // cannot fold tensor.pad away.
+ bool disjoint = false;
+ for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
+ // If the dimension has dynamic offset/size, we cannot guarantee
+ // disjoint. So just skip it.
+ if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
+ insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
+ prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
+ continue;
+
+ // Get the range start and end, inclusively for both.
+ int64_t prevStart = prevOp.getStaticOffset(i);
+ int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
+ prevOp.getStaticStride(i);
+ int64_t nextStart = insertOp.getStaticOffset(i);
+ int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
+ insertOp.getStaticStride(i);
+ if (prevEnd < nextStart || nextEnd < prevStart) {
+ disjoint = true;
+ break;
+ }
+ }
+
+ if (!disjoint)
+ break;
+ firstDest = prevOp.dest();
+ }
+
+ // Check whether the first destination is a fill op. For overlapped cases,
+ // this also cannot be true.
+ auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
if (!dstFillOp)
return failure();
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index ca5546bd2697a..a2be638fd8078 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -642,3 +642,108 @@ func @insert_pad_into_fill(%input: tensor<?x?x?xf32>, %low0: index, %low1: index
%0 = tensor.insert_slice %pad into %fill[0, 1, 2] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
return %0: tensor<8x384x384xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @multi_insert_pad_into_fill
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<7x123x124xf32>, %[[A:.+]]: tensor<8x128x128xf32>, %[[OFFSET:.+]]: index)
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[A]] into %[[FILL]][%[[OFFSET]], 0, 0] [8, 128, 128] [1, 1, 1]
+// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[A]] into %[[INSERT0]][0, 128, %[[OFFSET]]] [8, 128, 128] [1, 1, 1]
+// CHECK: tensor.insert_slice %[[INPUT]] into %[[INSERT1]][1, 2, 256] [7, 123, 124] [1, 1, 1]
+func @multi_insert_pad_into_fill(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
+ %f0 = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
+ ^bb0(%arg3: index, %arg4: index, %arg5: index):
+ tensor.yield %f0 : f32
+ } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
+ %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
+ %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+ %0 = tensor.insert_slice %a into %fill[%offset, 0, 0] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+ %1 = tensor.insert_slice %a into %0 [0, 128, %offset][8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+ %2 = tensor.insert_slice %pad into %1 [0, 0, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+ return %2: tensor<8x384x384xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @multi_insert_pad_into_fill_overlap
+func @multi_insert_pad_into_fill_overlap(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
+ %f0 = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ // CHECK: tensor.pad
+ %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
+ ^bb0(%arg3: index, %arg4: index, %arg5: index):
+ tensor.yield %f0 : f32
+ } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
+ %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
+ %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+ %0 = tensor.insert_slice %a into %fill[%offset, 0, 0] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+ %1 = tensor.insert_slice %a into %0 [0, 0, 129] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+ // Range overlap with %1 at dim#3
+ %2 = tensor.insert_slice %pad into %1 [0, 0, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+ return %2: tensor<8x384x384xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @multi_insert_pad_into_fill_overlap
+func @multi_insert_pad_into_fill_overlap(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
+ %f0 = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ // CHECK: tensor.pad
+ %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
+ ^bb0(%arg3: index, %arg4: index, %arg5: index):
+ tensor.yield %f0 : f32
+ } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
+ %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
+ %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+ %0 = tensor.insert_slice %a into %fill[0, 0, %offset] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+ %1 = tensor.insert_slice %a into %0 [0, 128, 255] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+ // Range overlap with %0 at dim#3
+ %2 = tensor.insert_slice %pad into %1 [0, 0, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+ return %2: tensor<8x384x384xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @multi_insert_pad_into_fill
+func @multi_insert_pad_into_fill(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
+ %f0 = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ // CHECK-NOT: tensor.pad
+ %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
+ ^bb0(%arg3: index, %arg4: index, %arg5: index):
+ tensor.yield %f0 : f32
+ } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
+ %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
+ %fill = linalg.fill(%f0, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+  // Overlap between %0 and %1 is fine, but overlap with %2 is not.
+ // CHECK-COUNT-3: tensor.insert_slice
+ %0 = tensor.insert_slice %a into %fill[0, 0, %offset] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+ %1 = tensor.insert_slice %a into %0 [0, 1, %offset] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+ %2 = tensor.insert_slice %pad into %1 [0, 256, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+ return %2: tensor<8x384x384xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @multi_insert_pad_into_fill_mismatch
+func @multi_insert_pad_into_fill_mismatch(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
+ %f0 = arith.constant 0.0 : f32
+ %f1 = arith.constant 1.0 : f32
+ %c0 = arith.constant 0 : index
+ // CHECK: tensor.pad
+ %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
+ ^bb0(%arg3: index, %arg4: index, %arg5: index):
+ tensor.yield %f0 : f32
+ } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
+ %init = linalg.init_tensor [8, 384, 384] : tensor<8x384x384xf32>
+ // Different filling value than padding value.
+ %fill = linalg.fill(%f1, %init) : f32, tensor<8x384x384xf32> -> tensor<8x384x384xf32>
+ %0 = tensor.insert_slice %a into %fill[%offset, 0, 0] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+ %1 = tensor.insert_slice %a into %0 [0, 128, %offset][8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+ %2 = tensor.insert_slice %pad into %1 [0, 0, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
+ return %2: tensor<8x384x384xf32>
+}
More information about the Mlir-commits
mailing list