[Mlir-commits] [mlir] [mlir][scf] fuse `tensor.pack` as consumer (PR #103715)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 15 19:46:43 PDT 2024


https://github.com/Yun-Fly updated https://github.com/llvm/llvm-project/pull/103715

>From 89ef37768846efab3d1c9bdcde3a7111c39d4a57 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Wed, 14 Aug 2024 01:09:34 -0700
Subject: [PATCH 1/2] fuse `tensor.pack` as consumer

---
 .../Tensor/IR/TensorTilingInterfaceImpl.cpp   | 91 +++++++++++++++++++
 .../tile-and-fuse-consumer.mlir               | 59 ++++++++++++
 2 files changed, 150 insertions(+)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 361340a4e62f2d..51c232ae77fe6c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -246,6 +246,97 @@ struct PackOpTiling
       return failure();
     return tilingResult.value();
   }
+
+  /// Method to return the position of the iteration domain tile computed by
+  /// the tiled operation. In the current `tensor.pack` context, the
+  /// `resultOffsets` and `resultSizes` cover only the outer dimensions.
+  LogicalResult getIterationDomainTileFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVectorImpl<OpFoldResult> &resultOffsets,
+      SmallVectorImpl<OpFoldResult> &resultSizes) const {
+    auto packOp = cast<PackOp>(op);
+    Location loc = packOp.getLoc();
+
+    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
+    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+        packOp.getDimAndTileMapping();
+    for (auto dim : packOp.getOuterDimsPerm()) {
+      if (dimAndTileMapping.count(dim)) {
+        FailureOr<int64_t> cstSize =
+            ValueBoundsConstraintSet::computeConstantBound(
+                presburger::BoundType::UB, sizes[dim],
+                /*stopCondition=*/nullptr, /*closedUB=*/true);
+        std::optional<int64_t> cstInnerSize =
+            getConstantIntValue(dimAndTileMapping[dim]);
+        // Only perfect tiling is supported for now; bail out otherwise.
+        if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
+          return failure();
+        }
+
+        using AV = affine::AffineValueExpr;
+        affine::AffineBuilder ab(b, loc);
+        AffineExpr dim0, sym;
+        bindDims(b.getContext(), dim0);
+        bindSymbols(b.getContext(), sym);
+        auto avOffset = AV(dim0).bind(offsets[dim]);
+        auto avSize = AV(dim0).bind(sizes[dim]);
+        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
+        outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
+        outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
+      } else {
+        outerDimOffsets.push_back(offsets[dim]);
+        outerDimSizes.push_back(sizes[dim]);
+      }
+    }
+
+    resultOffsets = outerDimOffsets;
+    resultSizes = outerDimSizes;
+    return success();
+  }
+
+  /// Method to return the tiled implementation of tensor.pack as a consumer.
+  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    auto packOp = cast<PackOp>(op);
+    Location loc = packOp.getLoc();
+
+    int64_t inputRank = packOp.getSourceRank();
+    auto oneAttr = b.getI64IntegerAttr(1);
+    SmallVector<OpFoldResult> strides(inputRank, oneAttr);
+
+    SmallVector<Value> tiledOperands;
+    tiledOperands.push_back(b.create<ExtractSliceOp>(loc, packOp.getSource(),
+                                                     offsets, sizes, strides));
+
+    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
+    if (failed(getIterationDomainTileFromOperandTile(
+            op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
+            outerDimSizes)))
+      return failure();
+
+    SmallVector<OpFoldResult> outputOffsets, outputSizes;
+    if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
+                                     outputOffsets, outputSizes)))
+      return failure();
+
+    strides.append(packOp.getDestRank() - inputRank, oneAttr);
+    auto extractSlice = b.create<ExtractSliceOp>(
+        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
+    tiledOperands.push_back(extractSlice);
+
+    if (auto val = packOp.getPaddingValue())
+      tiledOperands.push_back(val);
+    for (auto tile : packOp.getInnerTiles())
+      tiledOperands.push_back(tile);
+
+    Operation *tiledPackOp = b.create<PackOp>(
+        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
+
+    return TilingResult{{tiledPackOp},
+                        SmallVector<Value>(tiledPackOp->getResults())};
+  }
 };
 
 struct UnpackTileDimInfo {
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 400b558e37fcda..741dfbfb1cd5c2 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -315,3 +315,62 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:       }
 //      CHECK:   }
 //      CHECK:   return %[[FINAL_RESULT]]#1 :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+    func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
+        %c4 = arith.constant 4 : index
+        %c64 = arith.constant 64 : index
+        %c0 = arith.constant 0 : index
+        %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+            %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+            %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+                ^bb0(%in: f32, %in_16: f32, %out: f32):
+                %13 = arith.mulf %in, %in_16 : f32
+                %14 = arith.addf %out, %13 : f32
+                linalg.yield %14 : f32
+            } -> tensor<32x32xf32>
+            scf.forall.in_parallel {
+                tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+            }
+        }
+        %output = tensor.empty() : tensor<4x32x16xf32>
+        %pack = tensor.pack %1 outer_dims_perm = [0, 1] inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32>
+        return %pack : tensor<4x32x16xf32>
+    }
+}
+  
+module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+        %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+        : (!transform.any_op) -> !transform.any_op
+        %a, %b = transform.test.fuse_consumer %slice_op
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+        transform.yield
+    }
+}
+//      CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+//      CHECK: func.func @fuse_pack_consumer_into_scf_forall(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
+//      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<4x32x16xf32>
+//      CHECK:   %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME:   {
+//      CHECK:      %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:      %[[GENERIC_OUT:.*]] = linalg.generic
+// CHECK-SAME:              outs(%[[GENERIC_OUT_SLICE]] :
+//      CHECK:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]])
+//      CHECK:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
+//      CHECK:      %[[TILED_PACK_OUT:.*]] = tensor.pack %[[GENERIC_OUT]]
+// CHECK-SAME:                              outer_dims_perm = [0, 1] inner_dims_pos = [0] inner_tiles = [16]
+// CHECK-SAME:                              into %[[TILED_PACK_DEST]]
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]],  %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
+//      CHECK:          tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+//      CHECK:       }
+//      CHECK:   }
+//      CHECK:   return %[[FINAL_RESULT]]#1 :

>From d7fe9c7daf9021a0bb9a9cff57d9c0f0d51ff8a6 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Thu, 15 Aug 2024 19:46:24 -0700
Subject: [PATCH 2/2] fix comment

---
 mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 51c232ae77fe6c..0b29fb45b03f6a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -255,6 +255,9 @@ struct PackOpTiling
       ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
       SmallVectorImpl<OpFoldResult> &resultOffsets,
       SmallVectorImpl<OpFoldResult> &resultSizes) const {
+    if (operandNumber != 0)
+      return failure();
+
     auto packOp = cast<PackOp>(op);
     Location loc = packOp.getLoc();
 
@@ -299,6 +302,9 @@ struct PackOpTiling
   FailureOr<TilingResult> getTiledImplementationFromOperandTile(
       Operation *op, OpBuilder &b, unsigned operandNumber,
       ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    if (operandNumber != 0)
+      return failure();
+
     auto packOp = cast<PackOp>(op);
     Location loc = packOp.getLoc();
 



More information about the Mlir-commits mailing list