[Mlir-commits] [mlir] d6590c1 - [MLIR] Add allow Insert/extract slice option to pack/unpack op (#117340)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 10 09:31:34 PST 2024


Author: Zhuoran Yin
Date: 2024-12-10T11:31:30-06:00
New Revision: d6590c1bcb1b15b3b3f9f0ee6f0a6ff2b10b1e4f

URL: https://github.com/llvm/llvm-project/commit/d6590c1bcb1b15b3b3f9f0ee6f0a6ff2b10b1e4f
DIFF: https://github.com/llvm/llvm-project/commit/d6590c1bcb1b15b3b3f9f0ee6f0a6ff2b10b1e4f.diff

LOG: [MLIR] Add allow Insert/extract slice option to pack/unpack op (#117340)

This PR adds the options below. Both default to true, so the original
lowering behavior of the pack and unpack ops is unchanged.
 - lowerPadLikeWithInsertSlice for packOp (default = true)
 - lowerUnpadLikeWithExtractSlice for unPackOp (default = true)

The motivation of this PR is finer-grained control over the lowering of
pack and unpack ops. This is particularly useful when we want to
guarantee that no additional insert_slice or extract_slice ops are
generated that interfere with tiling. With the original lowering pipeline,
packOp and unPackOp may be lowered to insert_slice and extract_slice when
the outer (high) dimensions are unit dimensions and no transpose is
involved. Under such circumstances, these insert_slice and extract_slice
ops block producer/consumer fusion in tile + fuse transforms. With this
PR, we can disable that lowering path so that producer/consumer fusion
goes through as expected.

Added: 
    mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index dc10f3a1c58ae3..2e713bca24efc5 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -559,7 +559,8 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
     Return handles to the newly produced pad, expand_shape and transpose ops.
   }];
 
-  let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target);
+  let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target,
+                       DefaultValuedAttr<BoolAttr, "true">:$lowerPadLikeWithInsertSlice);
   let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
                       Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
                       Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);
@@ -599,7 +600,8 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
     Return handles to the newly produced empty, transpose, collapse_shape and extract_slice ops.
   }];
 
-  let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target);
+  let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target,
+                       DefaultValuedAttr<BoolAttr, "true">:$lowerUnpadLikeWithExtractSlice);
   let results = (outs Transform_ConcreteOpType<"tensor.empty">:$empty_op,
                       Transform_ConcreteOpType<"linalg.transpose">:$transpose_op,
                       Transform_ConcreteOpType<"tensor.collapse_shape">:$collapse_shape_op,

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index f31371ec6a0540..1dc700f22c2027 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1121,7 +1121,8 @@ struct LowerPackResult {
 
 /// Rewrite pack as pad + reshape + transpose.
 FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
-                                     tensor::PackOp packOp);
+                                     tensor::PackOp packOp,
+                                     bool lowerPadLikeWithInsertSlice = true);
 
 struct LowerUnPackOpResult {
   tensor::EmptyOp emptyOp;
@@ -1131,8 +1132,9 @@ struct LowerUnPackOpResult {
 };
 
 /// Rewrite pack as empty + transpose + reshape + extract_slice.
-FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
-                                           tensor::UnPackOp unPackOp);
+FailureOr<LowerUnPackOpResult>
+lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
+            bool lowerUnpadLikeWithExtractSlice = true);
 
 /// Struct to hold the result of a `pack` call.
 struct PackResult {

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e08be7d2ebd6ae..8839faf4cafb2d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1176,7 +1176,9 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
     transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
+  bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
+  FailureOr<LowerPackResult> res =
+      lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
   if (failed(res)) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "cannot lower to pad + expand + transpose";
@@ -1196,7 +1198,9 @@ DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
     transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
+  bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
+  FailureOr<LowerUnPackOpResult> res =
+      lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
   if (failed(res)) {
     DiagnosedSilenceableFailure diag =
         emitSilenceableError()

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index eeaa70c0b65892..21141f161057e5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -217,7 +217,8 @@ struct PackedOperandsDimList {
 } // namespace
 
 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
-                                             tensor::PackOp packOp) {
+                                             tensor::PackOp packOp,
+                                             bool lowerPadLikeWithInsertSlice) {
   // 1. Filter out NYI cases.
   auto packedTensorType =
       cast<RankedTensorType>(packOp->getResultTypes().front());
@@ -295,7 +296,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
       llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
       DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
 
-  if (packOp.isLikePad()) {
+  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
     // Pack ops which operate as simple pads may not produce legal
     // tensor.insert_slice operations when the packed type does not rank reduce
     // to the padded type.
@@ -351,8 +352,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   return LowerPackResult{padOp, reshapeOp, transposeOp};
 }
 
-FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
-                                                   tensor::UnPackOp unPackOp) {
+FailureOr<LowerUnPackOpResult>
+linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
+                    bool lowerUnpadLikeWithExtractSlice) {
   Location loc = unPackOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
@@ -362,7 +364,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
 
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
   auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
-  if (unPackOp.isLikeUnPad()) {
+  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
     // This unpack is just a plain unpad.
     // Just extract the slice from the higher ranked tensor.
     ArrayRef<int64_t> destShape = destTensorType.getShape();

diff  --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 7aadf190695630..5f8ff36a165786 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -96,6 +96,34 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// This is the same as pack_as_pad, but since we explicitly set {lowerPadLikeWithInsertSlice = false}, it should not
+// be lowered to insert_slice.
+// CHECK-LABEL: func.func @pack_as_pad_disabled_insert_slice(
+func.func @pack_as_pad_disabled_insert_slice(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+  // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
+  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<129x47x16x16xf32>
+  //  CHECK-DAG: %[[PAD:.*]] = tensor.pad %[[ARG0]]
+  //  CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]]
+  //      CHECK: %[[PAD_EXPANDED:.*]] = tensor.expand_shape %[[PAD]]
+  //  CHECK-DAG: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]]
+  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
+  return %pack :  tensor<1x1x1x1x136x64x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.pack">
+    transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false}: (!transform.op<"tensor.pack">)
+      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+      transform.yield
+  }
+}
+
+// -----
+
 // Check that we don't lower the following pack as a pad.
 // Although all the outer most dimensions in the resulting shape are 1s,
 // some of the original dimensions are not part of the inner_dims_pos, hence
@@ -233,6 +261,38 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// This is the same as unpack_as_pad, but since we explicitly set {lowerUnpadLikeWithExtractSlice = false}, it should not
+// be lowered to extract_slice.
+// CHECK-LABEL: func.func @unpack_as_pad_disabled_extract_slice(
+func.func @unpack_as_pad_disabled_extract_slice(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+
+  // tensor.unpack is lowered to linalg.transpose + tensor.collapse_shape + tensor.extract_slice
+  // CHECK-DAG: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
+  // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+  //     CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]]
+  //     CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]]
+  // CHECK-DAG: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]]
+  %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
+  return %pack : tensor<129x47x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+    transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false}: (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+          transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func.func @pack_with_outer_dims_perm(
 func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
                                      %dest: tensor<200x4x16x100x16x32xi32>)
@@ -572,7 +632,7 @@ func.func @unpack_fully_dynamic(%source: tensor<?x?x?x?xf32>, %dest: tensor<?x?x
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
-      : (!transform.any_op) -> !transform.op<"tensor.unpack"> 
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
     transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
           -> (!transform.op<"tensor.empty">,
           !transform.op<"linalg.transpose">,
@@ -627,9 +687,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK-LABEL: @unpack_with_outer_dims_perm
 //  CHECK-SAME: %[[ARG0:.*]]: tensor<32x64xf32>, %[[ARG1:.*]]: tensor<2x4x32x8xf32>
 //       CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<4x8x2x32xf32>
-//       CHECK: %[[TRAN:.*]] = linalg.transpose 
-//  CHECK-SAME:   ins(%[[ARG1]] : tensor<2x4x32x8xf32>) 
-//  CHECK-SAME:   outs(%[[EMPTY]] : tensor<4x8x2x32xf32>) 
+//       CHECK: %[[TRAN:.*]] = linalg.transpose
+//  CHECK-SAME:   ins(%[[ARG1]] : tensor<2x4x32x8xf32>)
+//  CHECK-SAME:   outs(%[[EMPTY]] : tensor<4x8x2x32xf32>)
 //  CHECK-SAME:   permutation = [1, 3, 0, 2]
 //       CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
 //  CHECK-SAME:   : tensor<4x8x2x32xf32> into tensor<32x64xf32>
@@ -638,7 +698,7 @@ module attributes {transform.with_named_sequence} {
 //       CHECK: linalg.copy ins(%[[SLICE]]
 //  CHECK-SAME:   : tensor<32x64xf32>) outs(%[[ARG0]] : tensor<32x64xf32>) -> tensor<32x64xf32>
 func.func @unpack_with_outer_dims_perm(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
-  %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0] 
+  %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0]
     inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg0 : tensor<2x4x32x8xf32> -> tensor<32x64xf32>
   return %unpack : tensor<32x64xf32>
 }

diff  --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
new file mode 100644
index 00000000000000..faf7ff9ad7ed09
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
@@ -0,0 +1,240 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file -canonicalize | FileCheck %s
+
+// For the pack op, we use lowerPadLikeWithInsertSlice = false to ensure no insert_slice is generated.
+// This allows linalg.transpose to be fused as a producer operation. In the testcase below, linalg.transpose
+// is fused as a producer operation into the scf.forall loop.
+
+module {
+  // CHECK-label: func @fuse_pack_as_producer
+  // CHECK:       scf.forall {{.*}} {
+  // CHECK:         %[[PRODUCER:.*]] = linalg.transpose
+  // CHECK:         linalg.generic {{.*}} ins(%[[PRODUCER]]
+  // CHECK:         scf.forall.in_parallel
+  // CHECK:       }
+  func.func @fuse_pack_as_producer(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<4x4x128x256xf32> {
+    %dest = tensor.empty() : tensor<1x1x128x256xf32>
+    %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<128x256xf32> -> tensor<1x1x128x256xf32>
+
+    %out = tensor.empty() : tensor<4x4x128x256xf32>
+    %res = linalg.generic
+        {indexing_maps = [affine_map<(i, j, k, l) -> (0, 0, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>],
+         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%pack, %other: tensor<1x1x128x256xf32>, tensor<4x4x128x256xf32>)
+        outs(%out: tensor<4x4x128x256xf32>) {
+      ^bb0(%pack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %pack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<4x4x128x256xf32>
+
+    return %res : tensor<4x4x128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower pack operation.
+      %pack = transform.structured.match ops{["tensor.pack"]} in %arg1
+        : (!transform.any_op) -> !transform.op<"tensor.pack">
+      %paded, %expanded, %transpose = transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false}
+        : (!transform.op<"tensor.pack">)
+        -> (!transform.op<"tensor.pad">,
+            !transform.op<"tensor.expand_shape">,
+            !transform.op<"linalg.transpose">)
+
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+          : (!transform.any_op) -> !transform.any_op
+      // Tile the linalg operation with parallel forall loop tiling [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Fuse the transpose operation into the tiled loop.
+      transform.structured.fuse_into_containing_op %transpose into %forall_op
+          : (!transform.op<"linalg.transpose">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
+// -----
+// For the pack op, lowerPadLikeWithInsertSlice defaults to true, which generates insert_slice and blocks fusion.
+// In the testcase below, tensor.insert_slice, as a producer operation, cannot be fused into the scf.forall loop.
+
+module {
+  // CHECK-label: func @fuse_pack_as_producer_blocked_by_insert_slice
+  // CHECK:       %[[PRODUCER:.*]] = tensor.insert_slice
+  // CHECK:       scf.forall {{.*}} {
+  // CHECK:         linalg.generic {{.*}} ins(%[[PRODUCER]]
+  // CHECK:         scf.forall.in_parallel
+  // CHECK:       }
+  func.func @fuse_pack_as_producer_blocked_by_insert_slice(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<4x4x128x256xf32> {
+    %dest = tensor.empty() : tensor<1x1x128x256xf32>
+    %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<128x256xf32> -> tensor<1x1x128x256xf32>
+
+    %out = tensor.empty() : tensor<4x4x128x256xf32>
+    %res = linalg.generic
+        {indexing_maps = [affine_map<(i, j, k, l) -> (0, 0, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>],
+         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%pack, %other: tensor<1x1x128x256xf32>, tensor<4x4x128x256xf32>)
+        outs(%out: tensor<4x4x128x256xf32>) {
+      ^bb0(%pack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %pack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<4x4x128x256xf32>
+
+    return %res : tensor<4x4x128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower pack operation.
+      %pack = transform.structured.match ops{["tensor.pack"]} in %arg1
+        : (!transform.any_op) -> !transform.op<"tensor.pack">
+      %paded, %expanded, %transpose = transform.structured.lower_pack %pack
+        : (!transform.op<"tensor.pack">)
+        -> (!transform.op<"tensor.pad">,
+            !transform.op<"tensor.expand_shape">,
+            !transform.op<"linalg.transpose">)
+
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+          : (!transform.any_op) -> !transform.any_op
+      // Tile the linalg operation with parallel forall loop tiling [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Fuse the transpose operation into the tiled loop.
+      transform.structured.fuse_into_containing_op %transpose into %forall_op
+          : (!transform.op<"linalg.transpose">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
+// -----
+// For the unpack op, we use lowerUnpadLikeWithExtractSlice = false to ensure no extract_slice is generated.
+// This allows linalg.transpose to be fused as a consumer operation. In the testcase below, linalg.transpose
+// is fused as a consumer operation into the scf.forall loop.
+module {
+  // CHECK-label: func @fuse_unpack_as_consumer
+  // CHECK:       scf.forall {{.*}} {
+  // CHECK:         %[[CONSUMER:.*]] = linalg.generic
+  // CHECK:         linalg.transpose ins(%[[CONSUMER]]
+  // CHECK:         scf.forall.in_parallel
+  // CHECK:       }
+  func.func @fuse_unpack_as_consumer(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<128x256xf32> {
+    %out = tensor.empty() : tensor<1x1x128x256xf32>
+    %res = linalg.generic
+        {indexing_maps = [affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (0, 0, k, l)>],
+         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%src, %other: tensor<4x4x128x256xf32>, tensor<4x4x128x256xf32>)
+        outs(%out: tensor<1x1x128x256xf32>) {
+      ^bb0(%unpack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %unpack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<1x1x128x256xf32>
+
+    %dest = tensor.empty() : tensor<128x256xf32>
+    %unpack = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<1x1x128x256xf32> -> tensor<128x256xf32>
+
+    return %unpack : tensor<128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower unpack operation.
+      %unpack = transform.structured.match ops{["tensor.unpack"]} in %arg1
+          : (!transform.any_op) -> !transform.op<"tensor.unpack">
+      transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false}
+        : (!transform.op<"tensor.unpack">)
+        -> (!transform.op<"tensor.empty">,
+            !transform.op<"linalg.transpose">,
+            !transform.op<"tensor.collapse_shape">,
+            !transform.op<"tensor.extract_slice">)
+
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+          : (!transform.any_op) -> !transform.any_op
+      // Tile the linalg operation with parallel forall loop tiling [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Fuse the consumer operation into the tiled loop.
+      %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
+          : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
+      transform.test.fuse_consumer %slice_op
+        : (!transform.op<"tensor.parallel_insert_slice">) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
+// -----
+// For the unpack op, lowerUnpadLikeWithExtractSlice defaults to true, which generates extract_slice and blocks fusion.
+// In the testcase below, tensor.extract_slice, as a consumer operation, cannot be fused into the scf.forall loop.
+module {
+  // CHECK-label: func @fuse_unpack_as_consumer_blocked_by_extract_slice
+  // CHECK:       %[[CONSUMER:.*]] = scf.forall {{.*}} {
+  // CHECK:         %[[ADDF:.*]] = linalg.generic
+  // CHECK:         scf.forall.in_parallel
+  // CHECK:           tensor.parallel_insert_slice %[[ADDF]]
+  // CHECK:       }
+  // CHECK:       tensor.extract_slice %[[CONSUMER]]
+  func.func @fuse_unpack_as_consumer_blocked_by_extract_slice(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<128x256xf32> {
+    %out = tensor.empty() : tensor<1x1x128x256xf32>
+    %res = linalg.generic
+        {indexing_maps = [affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (0, 0, k, l)>],
+         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%src, %other: tensor<4x4x128x256xf32>, tensor<4x4x128x256xf32>)
+        outs(%out: tensor<1x1x128x256xf32>) {
+      ^bb0(%unpack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %unpack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<1x1x128x256xf32>
+
+    %dest = tensor.empty() : tensor<128x256xf32>
+    %unpack = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<1x1x128x256xf32> -> tensor<128x256xf32>
+
+    return %unpack : tensor<128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower unpack operation.
+      %unpack = transform.structured.match ops{["tensor.unpack"]} in %arg1
+          : (!transform.any_op) -> !transform.op<"tensor.unpack">
+      transform.structured.lower_unpack %unpack
+        : (!transform.op<"tensor.unpack">)
+        -> (!transform.op<"tensor.empty">,
+            !transform.op<"linalg.transpose">,
+            !transform.op<"tensor.collapse_shape">,
+            !transform.op<"tensor.extract_slice">)
+
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+          : (!transform.any_op) -> !transform.any_op
+      // Tile the linalg operation with parallel forall loop tiling [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Fuse the consumer operation into the tiled loop.
+      %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
+          : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
+      // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice
+      // is not a qualified consumer operation. Forcing this will yield a "could not fetch consumer
+      // to fuse" error.
+      transform.yield
+    }
+  }
+}


        


More information about the Mlir-commits mailing list