[Mlir-commits] [mlir] f1595ec - [mlir] Fix bug in UnPackOp tiling implementation causing infinite loop (#113571)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 24 18:32:48 PDT 2024


Author: Max191
Date: 2024-10-24T21:32:45-04:00
New Revision: f1595ecfdce5387e41826fd72ff930a1a39ae398

URL: https://github.com/llvm/llvm-project/commit/f1595ecfdce5387e41826fd72ff930a1a39ae398
DIFF: https://github.com/llvm/llvm-project/commit/f1595ecfdce5387e41826fd72ff930a1a39ae398.diff

LOG: [mlir] Fix bug in UnPackOp tiling implementation causing infinite loop (#113571)

This fixes a bug in the tiling implementation of tensor.unpack that was
causing an infinite loop when certain unpack ops get tiled and fused as
a producer. The tiled implementation of tensor.unpack sometimes needs to
create an additional tensor.extract_slice on the result of the tiled
unpack op, but this slice was getting added to the `generatedSlices` of
the tiling result. The `generatedSlices` are used to find the next
producers to fuse, so it caused an infinite loop of fusing the same
unpack op after it was already in the loop. This fixes the bug by adding
the tensor.extract_slice of the unpack op's source to `generatedSlices`
instead of the slice of its result.

Signed-off-by: Max Dawkins <max.dawkins at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
    mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 104d6ae1f9f6b5..ba41904b370991 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -554,13 +554,14 @@ struct UnPackOpTiling
     sliceSrcIndices.append(numInnerTiles, zeroAttr);
     sliceSrcSizes.append(unpackOp.getMixedTiles());
     sliceSrcStrides.append(numInnerTiles, oneAttr);
-    Value sliceSource =
+    SmallVector<Operation *> generatedSlices;
+    ExtractSliceOp sliceSource =
         b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                  sliceSrcSizes, sliceSrcStrides);
+    generatedSlices.push_back(sliceSource);
 
     SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
     Value sliceDest;
-    SmallVector<Operation *> generatedSlices;
     if (isPerfectTilingCase) {
       auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
                                                   offsets, sizes, destStrides);
@@ -571,7 +572,7 @@ struct UnPackOpTiling
                                     unpackOp.getDestType().getElementType());
     }
 
-    SmallVector<Value> tiledOperands = {sliceSource, sliceDest};
+    SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
     for (auto tile : unpackOp.getInnerTiles())
       tiledOperands.push_back(tile);
 
@@ -586,7 +587,6 @@ struct UnPackOpTiling
     auto extractSlice =
         b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                  resultOffsetsFromDest, sizes, destStrides);
-    generatedSlices.push_back(extractSlice);
     return TilingResult{
         {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
   }

diff  --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index 3ea1929e4ed785..5f7663af773a4a 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -587,3 +587,50 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:     %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[IF_RESULT]]
 //       CHECK:     scf.yield %[[INSERT_SLICE]]
 //       CHECK:   return %[[FOR_RESULT]]
+
+// -----
+
+func.func @imperfect_unpack_producer_fusion(%source: tensor<1x1x288x8x4xf32>, %dest: tensor<1x2x1152xf32>) -> tensor<1x2x1152xf32> {
+  %0 = tensor.unpack %source
+      outer_dims_perm = [0, 1, 2]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [8, 4] into %dest
+      : tensor<1x1x288x8x4xf32> -> tensor<1x2x1152xf32>
+  %1 = tensor.empty() : tensor<1x2x1152xf32>
+  %cst = arith.constant 1.0 : f32
+  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+                       iterator_types = ["parallel", "parallel", "parallel"]}
+                       ins(%0 : tensor<1x2x1152xf32>)
+                       outs(%1 : tensor<1x2x1152xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %7 = arith.addf %in, %cst : f32
+    linalg.yield %7 : f32
+  } -> tensor<1x2x1152xf32>
+  return %2 : tensor<1x2x1152xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.structured.fuse %matmul [0, 1, 0]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @imperfect_unpack_producer_fusion
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<1x1x288x8x4xf32>
+//  CHECK-SAME:     %[[ARG1:.+]]: tensor<1x2x1152xf32>
+//       CHECK:   %[[FOR_RESULT:.+]] = scf.for{{.*}}iter_args(%[[ITER_ARG:.+]] = {{.*}})
+//       CHECK:     %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+//       CHECK:     %[[UNPACK:.+]] = tensor.unpack %[[SLICE]]
+//   CHECK-DAG:     %[[UNPACK_SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+//   CHECK-DAG:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ITER_ARG]]
+//       CHECK:     %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[UNPACK_SLICE]]
+//  CHECK-SAME:         outs(%[[INIT_SLICE]]
+//       CHECK:     %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER_ARG]]
+//       CHECK:     scf.yield %[[INSERT_SLICE]]
+//       CHECK:   return %[[FOR_RESULT]]


        


More information about the Mlir-commits mailing list