[Mlir-commits] [mlir] f1595ec - [mlir] Fix bug in UnPackOp tiling implementation causing infinite loop (#113571)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 24 18:32:48 PDT 2024
Author: Max191
Date: 2024-10-24T21:32:45-04:00
New Revision: f1595ecfdce5387e41826fd72ff930a1a39ae398
URL: https://github.com/llvm/llvm-project/commit/f1595ecfdce5387e41826fd72ff930a1a39ae398
DIFF: https://github.com/llvm/llvm-project/commit/f1595ecfdce5387e41826fd72ff930a1a39ae398.diff
LOG: [mlir] Fix bug in UnPackOp tiling implementation causing infinite loop (#113571)
This fixes a bug in the tiling implementation of tensor.unpack that was
causing an infinite loop when certain unpack ops get tiled and fused as
a producer. The tiled implementation of tensor.unpack sometimes needs to
create an additional tensor.extract_slice on the result of the tiled
unpack op, but this slice was getting added to the `generatedSlices` of
the tiling result. The `generatedSlices` are used to find the next
producers to fuse, so it caused an infinite loop of fusing the same
unpack op after it was already in the loop. This fixes the bug by adding
the slice of the source to `generatedSlices` instead of the slice of the
result.
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 104d6ae1f9f6b5..ba41904b370991 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -554,13 +554,14 @@ struct UnPackOpTiling
sliceSrcIndices.append(numInnerTiles, zeroAttr);
sliceSrcSizes.append(unpackOp.getMixedTiles());
sliceSrcStrides.append(numInnerTiles, oneAttr);
- Value sliceSource =
+ SmallVector<Operation *> generatedSlices;
+ ExtractSliceOp sliceSource =
b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
sliceSrcSizes, sliceSrcStrides);
+ generatedSlices.push_back(sliceSource);
SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
Value sliceDest;
- SmallVector<Operation *> generatedSlices;
if (isPerfectTilingCase) {
auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
offsets, sizes, destStrides);
@@ -571,7 +572,7 @@ struct UnPackOpTiling
unpackOp.getDestType().getElementType());
}
- SmallVector<Value> tiledOperands = {sliceSource, sliceDest};
+ SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
for (auto tile : unpackOp.getInnerTiles())
tiledOperands.push_back(tile);
@@ -586,7 +587,6 @@ struct UnPackOpTiling
auto extractSlice =
b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
resultOffsetsFromDest, sizes, destStrides);
- generatedSlices.push_back(extractSlice);
return TilingResult{
{tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index 3ea1929e4ed785..5f7663af773a4a 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -587,3 +587,50 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[IF_RESULT]]
// CHECK: scf.yield %[[INSERT_SLICE]]
// CHECK: return %[[FOR_RESULT]]
+
+// -----
+
+func.func @imperfect_unpack_producer_fusion(%source: tensor<1x1x288x8x4xf32>, %dest: tensor<1x2x1152xf32>) -> tensor<1x2x1152xf32> {
+ %0 = tensor.unpack %source
+ outer_dims_perm = [0, 1, 2]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [8, 4] into %dest
+ : tensor<1x1x288x8x4xf32> -> tensor<1x2x1152xf32>
+ %1 = tensor.empty() : tensor<1x2x1152xf32>
+ %cst = arith.constant 1.0 : f32
+ %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%0 : tensor<1x2x1152xf32>)
+ outs(%1 : tensor<1x2x1152xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %7 = arith.addf %in, %cst : f32
+ linalg.yield %7 : f32
+ } -> tensor<1x2x1152xf32>
+ return %2 : tensor<1x2x1152xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.generic"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.structured.fuse %matmul [0, 1, 0]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func @imperfect_unpack_producer_fusion
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x288x8x4xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x2x1152xf32>
+// CHECK: %[[FOR_RESULT:.+]] = scf.for{{.*}}iter_args(%[[ITER_ARG:.+]] = {{.*}})
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[SLICE]]
+// CHECK-DAG: %[[UNPACK_SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+// CHECK-DAG: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]]
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[UNPACK_SLICE]]
+// CHECK-SAME: outs(%[[INIT_SLICE]]
+// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER_ARG]]
+// CHECK: scf.yield %[[INSERT_SLICE]]
+// CHECK: return %[[FOR_RESULT]]
More information about the Mlir-commits
mailing list