[Mlir-commits] [mlir] [mlir] Fix bug in UnPackOp tiling implementation causing infinite loop (PR #113571)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 24 07:09:02 PDT 2024
https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/113571
This fixes a bug in the tiling implementation of `tensor.unpack` that caused an infinite loop when certain unpack ops were tiled and fused as producers. The tiled implementation of `tensor.unpack` sometimes needs to create an additional `tensor.extract_slice` on the result of the tiled unpack op, and this slice was being added to the `generatedSlices` of the tiling result. Because `generatedSlices` is used to find the next producers to fuse, fusion kept selecting the same unpack op again even after it had already been fused into the loop, which resulted in an infinite loop. The fix is to add the `tensor.extract_slice` of the unpack source to `generatedSlices` instead of the slice taken from the tiled unpack result.
>From 92f50ad99f6c4dc3b4f96a72e378122c2d4543e0 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 24 Oct 2024 09:59:30 -0400
Subject: [PATCH] [mlir] Fix bug in UnPackOp tiling implementation causing
infinite loop
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
.../Tensor/IR/TensorTilingInterfaceImpl.cpp | 8 ++--
.../tile-and-fuse-using-interface.mlir | 47 +++++++++++++++++++
2 files changed, 51 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 104d6ae1f9f6b5..ba41904b370991 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -554,13 +554,14 @@ struct UnPackOpTiling
sliceSrcIndices.append(numInnerTiles, zeroAttr);
sliceSrcSizes.append(unpackOp.getMixedTiles());
sliceSrcStrides.append(numInnerTiles, oneAttr);
- Value sliceSource =
+ SmallVector<Operation *> generatedSlices;
+ ExtractSliceOp sliceSource =
b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
sliceSrcSizes, sliceSrcStrides);
+ generatedSlices.push_back(sliceSource);
SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
Value sliceDest;
- SmallVector<Operation *> generatedSlices;
if (isPerfectTilingCase) {
auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
offsets, sizes, destStrides);
@@ -571,7 +572,7 @@ struct UnPackOpTiling
unpackOp.getDestType().getElementType());
}
- SmallVector<Value> tiledOperands = {sliceSource, sliceDest};
+ SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
for (auto tile : unpackOp.getInnerTiles())
tiledOperands.push_back(tile);
@@ -586,7 +587,6 @@ struct UnPackOpTiling
auto extractSlice =
b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
resultOffsetsFromDest, sizes, destStrides);
- generatedSlices.push_back(extractSlice);
return TilingResult{
{tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index 3ea1929e4ed785..5f7663af773a4a 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -587,3 +587,50 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[IF_RESULT]]
// CHECK: scf.yield %[[INSERT_SLICE]]
// CHECK: return %[[FOR_RESULT]]
+
+// -----
+
+func.func @imperfect_unpack_producer_fusion(%source: tensor<1x1x288x8x4xf32>, %dest: tensor<1x2x1152xf32>) -> tensor<1x2x1152xf32> {
+ %0 = tensor.unpack %source
+ outer_dims_perm = [0, 1, 2]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [8, 4] into %dest
+ : tensor<1x1x288x8x4xf32> -> tensor<1x2x1152xf32>
+ %1 = tensor.empty() : tensor<1x2x1152xf32>
+ %cst = arith.constant 1.0 : f32
+ %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%0 : tensor<1x2x1152xf32>)
+ outs(%1 : tensor<1x2x1152xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %7 = arith.addf %in, %cst : f32
+ linalg.yield %7 : f32
+ } -> tensor<1x2x1152xf32>
+ return %2 : tensor<1x2x1152xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.generic"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.structured.fuse %matmul [0, 1, 0]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func @imperfect_unpack_producer_fusion
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x288x8x4xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x2x1152xf32>
+// CHECK: %[[FOR_RESULT:.+]] = scf.for{{.*}}iter_args(%[[ITER_ARG:.+]] = {{.*}})
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[SLICE]]
+// CHECK-DAG: %[[UNPACK_SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]]
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[UNPACK_SLICE]]
+// CHECK-SAME: outs(%[[INIT_SLICE]]
+// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER_ARG]]
+// CHECK: scf.yield %[[INSERT_SLICE]]
+// CHECK: return %[[FOR_RESULT]]
More information about the Mlir-commits
mailing list