[Mlir-commits] [mlir] [mlir] Fix bug in UnPackOp tiling implementation causing infinite loop (PR #113571)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 24 07:09:36 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: None (Max191)

<details>
<summary>Changes</summary>

This fixes a bug in the tiling implementation of tensor.unpack that caused an infinite loop when certain unpack ops were tiled and fused as a producer. The tiled implementation of tensor.unpack sometimes needs to create an additional tensor.extract_slice on the result of the tiled unpack op, and this slice was being added to the `generatedSlices` of the tiling result. Because the `generatedSlices` are used to find the next producers to fuse, the same unpack op kept being fused again even after it was already in the loop, which never terminated. This fixes the bug by adding the extract_slice of the source to `generatedSlices` instead of the extract_slice of the result.

---
Full diff: https://github.com/llvm/llvm-project/pull/113571.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (+4-4) 
- (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir (+47) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 104d6ae1f9f6b5..ba41904b370991 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -554,13 +554,14 @@ struct UnPackOpTiling
     sliceSrcIndices.append(numInnerTiles, zeroAttr);
     sliceSrcSizes.append(unpackOp.getMixedTiles());
     sliceSrcStrides.append(numInnerTiles, oneAttr);
-    Value sliceSource =
+    SmallVector<Operation *> generatedSlices;
+    ExtractSliceOp sliceSource =
         b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                  sliceSrcSizes, sliceSrcStrides);
+    generatedSlices.push_back(sliceSource);
 
     SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
     Value sliceDest;
-    SmallVector<Operation *> generatedSlices;
     if (isPerfectTilingCase) {
       auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
                                                   offsets, sizes, destStrides);
@@ -571,7 +572,7 @@ struct UnPackOpTiling
                                     unpackOp.getDestType().getElementType());
     }
 
-    SmallVector<Value> tiledOperands = {sliceSource, sliceDest};
+    SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
     for (auto tile : unpackOp.getInnerTiles())
       tiledOperands.push_back(tile);
 
@@ -586,7 +587,6 @@ struct UnPackOpTiling
     auto extractSlice =
         b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                  resultOffsetsFromDest, sizes, destStrides);
-    generatedSlices.push_back(extractSlice);
     return TilingResult{
         {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
   }
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index 3ea1929e4ed785..5f7663af773a4a 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -587,3 +587,50 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:     %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[IF_RESULT]]
 //       CHECK:     scf.yield %[[INSERT_SLICE]]
 //       CHECK:   return %[[FOR_RESULT]]
+
+// -----
+
+func.func @imperfect_unpack_producer_fusion(%source: tensor<1x1x288x8x4xf32>, %dest: tensor<1x2x1152xf32>) -> tensor<1x2x1152xf32> {
+  %0 = tensor.unpack %source
+      outer_dims_perm = [0, 1, 2]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [8, 4] into %dest
+      : tensor<1x1x288x8x4xf32> -> tensor<1x2x1152xf32>
+  %1 = tensor.empty() : tensor<1x2x1152xf32>
+  %cst = arith.constant 1.0 : f32
+  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+                       iterator_types = ["parallel", "parallel", "parallel"]}
+                       ins(%0 : tensor<1x2x1152xf32>)
+                       outs(%1 : tensor<1x2x1152xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %7 = arith.addf %in, %cst : f32
+    linalg.yield %7 : f32
+  } -> tensor<1x2x1152xf32>
+  return %2 : tensor<1x2x1152xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.structured.fuse %matmul [0, 1, 0]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @imperfect_unpack_producer_fusion
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<1x1x288x8x4xf32>
+//  CHECK-SAME:     %[[ARG1:.+]]: tensor<1x2x1152xf32>
+//       CHECK:   %[[FOR_RESULT:.+]] = scf.for{{.*}}iter_args(%[[ITER_ARG:.+]] = {{.*}})
+//       CHECK:     %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+//       CHECK:     %[[UNPACK:.+]] = tensor.unpack %[[SLICE]]
+//   CHECK-DAG:     %[[UNPACK_SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+//   CHECK-DAG:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]]
+//       CHECK:     %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[UNPACK_SLICE]]
+//  CHECK-SAME:         outs(%[[INIT_SLICE]]
+//       CHECK:     %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER_ARG]]
+//       CHECK:     scf.yield %[[INSERT_SLICE]]
+//       CHECK:   return %[[FOR_RESULT]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/113571


More information about the Mlir-commits mailing list