[Mlir-commits] [mlir] 0d03ba6 - [mlir][tensor] Implement TilingInterface for tensor.pack op.
    Hanhan Wang 
    llvmlistbot at llvm.org
       
    Mon Dec  5 14:00:18 PST 2022
    
    
  
Author: Hanhan Wang
Date: 2022-12-05T14:00:10-08:00
New Revision: 0d03ba62c55f5fc1edf28ccffe9dd4ffa3edd4d0
URL: https://github.com/llvm/llvm-project/commit/0d03ba62c55f5fc1edf28ccffe9dd4ffa3edd4d0
DIFF: https://github.com/llvm/llvm-project/commit/0d03ba62c55f5fc1edf28ccffe9dd4ffa3edd4d0.diff
LOG: [mlir][tensor] Implement TilingInterface for tensor.pack op.
We can compute the offsets and sizes for the slice of input because the
iteration domain is defined over outer loops. If the dimension is tiled,
the i-th index is the product of offset_i and inner_tile_i.
Unlike tiling a pad op, we do not have to deal with reading zero
data from the input, because the tiling sizes apply to the packed outer
dimensions. We will read either the entire tile or a partial tile for each
packed tile. The scf.if and tensor.generate ops are not needed in this
context.
Co-authored-by: Lorenzo Chelini <l.chelini at icloud.com>
Reviewed By: rengolin, mravishankar
Differential Revision: https://reviews.llvm.org/D138631
Added: 
    mlir/test/Dialect/Tensor/tiling.mlir
Modified: 
    mlir/include/mlir/Dialect/Affine/Utils.h
    mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
    mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed: 
    
################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index d5f02845715bd..29e10d808ccd1 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_AFFINE_UTILS_H
 
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 
 namespace mlir {
 
@@ -328,6 +329,56 @@ FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
 /// that would change the read within `memOp`.
 template <typename EffectType, typename T>
 bool hasNoInterveningEffect(Operation *start, T memOp);
+
+struct AffineValueExpr {
+  explicit AffineValueExpr(AffineExpr e) : e(e) {}
+  AffineValueExpr bind(Value v) {
+    this->v = v;
+    return *this;
+  }
+  AffineValueExpr bind(OpFoldResult v) {
+    this->v = v;
+    return *this;
+  }
+  operator AffineExpr() const { return e; }
+  operator OpFoldResult() const { return v; }
+  AffineExpr e;
+  OpFoldResult v;
+};
+
+/// Helper struct to build simple AffineValueExprs with minimal type inference
+/// support.
+struct AffineBuilder {
+  AffineBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {}
+  OpFoldResult add(AffineValueExpr lhs, AffineValueExpr rhs) {
+    return makeComposedFoldedAffineApply(b, loc, {lhs.e + rhs.e}, {lhs, rhs});
+  }
+  OpFoldResult sub(AffineValueExpr lhs, AffineValueExpr rhs) {
+    return makeComposedFoldedAffineApply(b, loc, {lhs.e - rhs.e}, {lhs, rhs});
+  }
+  OpFoldResult mul(AffineValueExpr lhs, AffineValueExpr rhs) {
+    return makeComposedFoldedAffineApply(b, loc, {lhs.e * rhs.e}, {lhs, rhs});
+  }
+  OpFoldResult ceil(AffineValueExpr lhs, AffineValueExpr rhs) {
+    return makeComposedFoldedAffineApply(b, loc, {lhs.e.ceilDiv(rhs.e)},
+                                         {lhs, rhs});
+  }
+  OpFoldResult min(ArrayRef<OpFoldResult> vals) {
+    return makeComposedFoldedAffineMin(
+        b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), b.getContext()),
+        vals);
+  }
+  OpFoldResult max(ArrayRef<OpFoldResult> vals) {
+    return makeComposedFoldedAffineMax(
+        b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), b.getContext()),
+        vals);
+  }
+
+private:
+  OpBuilder &b;
+  Location loc;
+};
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AFFINE_UTILS_H
diff  --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
index 8be5f6052a929..0c04762b8a40a 100644
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -57,10 +57,12 @@ add_mlir_dialect_library(MLIRTensorTilingInterfaceImpl
 
   LINK_LIBS PUBLIC
   MLIRAffineDialect
+  MLIRDialectUtils
   MLIRIR
   MLIRLinalgDialect
   MLIRSCFDialect
   MLIRSupport
   MLIRTensorDialect
+  MLIRTensorUtils
   MLIRTilingInterface
   )
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 49d31d2f1e487..b8210aabd782c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -8,10 +8,13 @@
 
 #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Interfaces/TilingInterface.h"
 
 using namespace mlir;
@@ -68,6 +71,145 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
   }
 };
 
+struct PackOpTiling
+    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {
+
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    // Note that here we only consider untiled dimensions and outer tiled data
+    // dimensions, the inner tiled data dimensions are materialized when
+    // building the body of the operation.
+    auto packOp = cast<PackOp>(op);
+    SmallVector<utils::IteratorType> iteratorTypes(
+        packOp.getSourceRank(), utils::IteratorType::parallel);
+    return iteratorTypes;
+  }
+
+  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
+    OpBuilder::InsertionGuard guard(b);
+    auto packOp = cast<PackOp>(op);
+    Location loc = packOp.getLoc();
+    int64_t rank = packOp.getSourceRank();
+    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+    ReifiedRankedShapedTypeDims resultShape;
+    (void)packOp.reifyResultShapes(b, resultShape);
+    SmallVector<Range> loopRanges(rank);
+    for (auto dim : llvm::seq<int64_t>(0, rank)) {
+      loopRanges[dim].offset = zero;
+      loopRanges[dim].stride = one;
+      loopRanges[dim].size = resultShape[0][dim];
+    }
+    return loopRanges;
+  }
+
+  SmallVector<Operation *>
+  getTiledImplementation(Operation *op, OpBuilder &b,
+                         ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes) const {
+    auto packOp = cast<PackOp>(op);
+    Location loc = packOp.getLoc();
+
+    // The tiling is applied on interchanged dimensions. We have to undo the
+    // interchange to map sizes and offsets to the original input.
+    int64_t inputRank = packOp.getSourceRank();
+    ArrayRef<int64_t> dimsToOuterBlock(packOp.getOuterDimsPerm());
+    SmallVector<OpFoldResult> origOffsets(offsets.begin(), offsets.end());
+    SmallVector<OpFoldResult> origSizes(sizes.begin(), sizes.end());
+    if (!dimsToOuterBlock.empty()) {
+      SmallVector<int64_t> inversedPerm =
+          invertPermutationVector(dimsToOuterBlock);
+      applyPermutationToVector<OpFoldResult>(origOffsets, inversedPerm);
+      applyPermutationToVector<OpFoldResult>(origSizes, inversedPerm);
+    }
+
+    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+        packOp.getDimAndTileMapping();
+    SmallVector<OpFoldResult> srcDimValues =
+        tensor::createDimValues(b, loc, packOp.getSource());
+    SmallVector<OpFoldResult> inputIndices, inputSizes;
+    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
+      using AV = AffineValueExpr;
+      AffineBuilder ab(b, loc);
+      AffineExpr dim0, dim1, sym;
+      bindDims(b.getContext(), dim0, dim1);
+      bindSymbols(b.getContext(), sym);
+      if (dimAndTileMapping.count(dim)) {
+        // If the data dimension is tiled, the i-th index is the product of
+        // offset_i and tile_i, and the i-th size is the product of sizes_i and
+        // tile_i.
+        auto avOffset = AV(dim0).bind(origOffsets[dim]);
+        auto avSize = AV(dim0).bind(origSizes[dim]);
+        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
+        inputIndices.push_back(ab.mul(avOffset, avTileSize));
+        inputSizes.push_back(ab.mul(avSize, avTileSize));
+      } else {
+        inputIndices.push_back(origOffsets[dim]);
+        inputSizes.push_back(origSizes[dim]);
+      }
+
+      // Limit the size of the input operand for incomplete tiles.
+      OpFoldResult dimSize = srcDimValues[dim];
+      auto avDimSize = AV(dim0).bind(dimSize);
+      auto avInputIdx = AV(dim1).bind(inputIndices.back());
+      inputSizes.back() =
+          ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
+    }
+
+    auto oneAttr = b.getI64IntegerAttr(1);
+    SmallVector<OpFoldResult> strides(inputRank, oneAttr);
+
+    SmallVector<Value> tiledOperands;
+    tiledOperands.push_back(b.create<ExtractSliceOp>(
+        loc, packOp.getSource(), inputIndices, inputSizes, strides));
+
+    SmallVector<OpFoldResult> outputOffsets, outputSizes;
+    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
+                                     outputSizes)))
+      return {};
+
+    strides.append(packOp.getDestRank() - inputRank, oneAttr);
+    auto extractSlice = b.create<ExtractSliceOp>(
+        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
+    tiledOperands.push_back(extractSlice);
+
+    if (auto val = packOp.getPaddingValue())
+      tiledOperands.push_back(val);
+    for (auto tile : packOp.getInnerTiles())
+      tiledOperands.push_back(tile);
+
+    Operation *tiledPackOp = b.create<PackOp>(
+        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
+
+    return {tiledPackOp};
+  }
+
+  LogicalResult
+  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
+                        ArrayRef<OpFoldResult> offsets,
+                        ArrayRef<OpFoldResult> sizes,
+                        SmallVector<OpFoldResult> &resultOffsets,
+                        SmallVector<OpFoldResult> &resultSizes) const {
+    // The iteration domain is over outer dimensions of packed layout. In this
+    // context, the outer dimensions of `resultOffsets` are `offsets`. The
+    // inner dimensions of `resultOffsets` are zeros because tiling is not
+    // applied to them.
+    auto packOp = cast<PackOp>(op);
+    int64_t inputRank = packOp.getSourceRank();
+    int64_t outputRank = packOp.getDestRank();
+    auto zeroAttr = b.getI64IntegerAttr(0);
+    resultOffsets.assign(offsets.begin(), offsets.end());
+    resultOffsets.append(outputRank - inputRank, zeroAttr);
+
+    ReifiedRankedShapedTypeDims outputShape;
+    (void)packOp.reifyResultShapes(b, outputShape);
+    resultSizes.assign(sizes.begin(), sizes.end());
+    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
+      resultSizes.push_back(getAsOpFoldResult(outputShape[0][dataTileDim]));
+
+    return success();
+  }
+};
+
 } // namespace
 
 Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
@@ -282,5 +424,6 @@ void mlir::tensor::registerTilingInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
     tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
+    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
   });
 }
diff  --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir
new file mode 100644
index 0000000000000..612367916bb4c
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/tiling.mlir
@@ -0,0 +1,214 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -canonicalize -cse -split-input-file | FileCheck %s
+
+// CHECK-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -32 + 128, 64)>
+// CHECK-DAG:   #[[MAP2:.+]] = affine_map<(d0) -> (d0 * -32 + 256, 128)>
+// CHECK:       func.func @NC_to_NCnc
+// CHECK-SAME:    %[[IN:.*]]: tensor<128x256xf32>,
+// CHECK-SAME:    %[[OUT:.*]]: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK:         %[[RES0:.*]] = scf.for %[[N:.*]] = %[[C0]] to %[[C4]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor<4x8x32x32xf32>) {
+// CHECK:           %[[RES1:.+]] = scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor<4x8x32x32xf32>) {
+// CHECK-DAG:         %[[IN_N:.+]] = affine.apply #[[MAP0]](%[[N]])
+// CHECK-DAG:         %[[IN_N_SZ:.*]] = affine.min #[[MAP1]]
+// CHECK-DAG:         %[[IN_C:.+]] = affine.apply #[[MAP0]](%[[C]])
+// CHECK-DAG:         %[[IN_C_SZ:.*]] = affine.min #[[MAP2]]
+// CHECK:             %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_N]], %[[IN_C]]] [%[[IN_N_SZ]], %[[IN_C_SZ]]] [1, 1] : tensor<128x256xf32> to tensor<?x?xf32>
+// CHECK:             %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[N]], %[[C]], 0, 0] [2, 4, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<2x4x32x32xf32>
+// CHECK:             %[[CAST_OUT:.*]] = tensor.cast %[[SUB_OUT]]
+// CHECK:             %[[SUB_RES:.*]] = tensor.pack
+// CHECK-SAME:          %[[SUB_IN]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[CAST_OUT]]
+// CHECK:             %[[CAST_RES:.*]] = tensor.cast %[[SUB_RES]]
+// CHECK:             %[[INSERT:.*]] = tensor.insert_slice %[[CAST_RES]] into %[[ITER1]]
+// CHECK:             scf.yield %[[INSERT]] : tensor<4x8x32x32xf32>
+// CHECK:           }
+// CHECK:           scf.yield %[[RES1:.*]] : tensor<4x8x32x32xf32>
+// CHECK:         }
+// CHECK:         return %[[RES0:.*]] : tensor<4x8x32x32xf32>
+// CHECK:       }
+func.func @NC_to_NCnc(%arg0: tensor<128x256xf32>, %arg1: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> {
+  %0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg1 : tensor<128x256xf32> -> tensor<4x8x32x32xf32>
+  return %0 : tensor<4x8x32x32xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4]
+}
+
+// -----
+
+// CHECK:       #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 8)>
+// CHECK:       #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -8 + 256, 16)>
+// CHECK:       func.func @KC_to_CKkc
+// CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
+// CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:     %[[C32:.+]] = arith.constant 32 : index
+// CHECK:         scf.for %[[C:.+]] = %[[C0]] to %[[C32]] step %[[C2]]
+// CHECK-DAG:         %[[IN_C:.+]] = affine.apply #[[MAP0]](%[[C]])
+// CHECK-DAG:         %[[IN_C_SZ:.+]] = affine.min #[[MAP1]](%[[C]])
+// CHECK:             %[[INPUT_SLICE:.+]] = tensor.extract_slice %[[IN]]
+// CHECK-SAME:          [0, %[[IN_C]]] [128, %[[IN_C_SZ]]]
+// CHECK:             %[[CAST_IN:.+]] = tensor.cast %[[INPUT_SLICE]]
+// CHECK:             %[[OUTPUT_SLICE:.+]] = tensor.extract_slice %{{.+}}[%[[C]], 0, 0, 0] [2, 4, 32, 8]
+// CHECK:             %[[CAST_OUT:.+]] = tensor.cast %[[OUTPUT_SLICE]]
+// CHECK:             tensor.pack
+// CHECK-SAME:          %[[CAST_IN]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8]
+// CHECK-SAME:          into %[[CAST_OUT]]
+func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>) -> tensor<32x4x32x8xf32> {
+  %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<128x256xf32> -> tensor<32x4x32x8xf32>
+  return %0 : tensor<32x4x32x8xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4]
+}
+
+// -----
+
+// CHECK-DAG:     #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-DAG:     #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -2 + 15, 8)>
+// CHECK:         func.func @pad_and_pack_static(
+// CHECK-SAME:      %[[IN:.*]]: tensor<13x15xf32>,
+// CHECK-SAME:      %[[OUT:.*]]: tensor<2x8x8x2xf32>,
+// CHECK-SAME:      %[[PAD:.*]]: f32) -> tensor<2x8x8x2xf32> {
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:       %[[RES0:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[OUT]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-DAG:         %[[IN_J:.*]] = affine.apply #[[MAP0]](%[[J]])
+// CHECK-DAG:         %[[IN_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])
+// CHECK:             %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][0, %[[IN_J]]] [13, %[[IN_J_SZ]]] [1, 1]
+// CHECK:             %[[CAST_IN:.*]] = tensor.cast %[[SUB_IN]]
+// CHECK:             %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][0, %[[J]], 0, 0] [2, 4, 8, 2] [1, 1, 1, 1]
+// CHECK:             %[[CAST_OUT:.*]] = tensor.cast %[[SUB_OUT]]
+// CHECK:             %[[SUB_RES:.*]] = tensor.pack
+// CHECK-SAME:          %[[CAST_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2]
+// CHECK-SAME:          into %[[CAST_OUT]]
+// CHECK:             %[[CAST_RES:.*]] = tensor.cast %[[SUB_RES]]
+// CHECK:             %[[INSERT:.*]] = tensor.insert_slice %[[CAST_RES]] into %[[ITER1]]
+// CHECK:             scf.yield %[[INSERT]] : tensor<2x8x8x2xf32>
+// CHECK:           }
+// CHECK:           return %[[RES0:.*]] : tensor<2x8x8x2xf32>
+// CHECK:         }
+func.func @pad_and_pack_static(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: f32) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4]
+}
+
+// -----
+
+// CHECK-DAG:     #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
+// CHECK-DAG:     #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+// CHECK-DAG:     #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 8)>
+// CHECK-DAG:     #[[MAP3:.+]] = affine_map<(d0, d1)[s0] -> (d1 * -8 + s0, d0 * 8)>
+// CHECK-DAG:     #[[MAP4:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-DAG:     #[[MAP5:.+]] = affine_map<(d0, d1)[s0] -> (d1 * -2 + s0, d0 * 2)>
+// CHECK:         func.func @pad_and_pack_partially_dynamic(
+// CHECK-SAME:      %[[IN:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:      %[[OUT:.*]]: tensor<?x?x8x2xf32>,
+// CHECK-SAME:      %[[PAD:.*]]: f32) -> tensor<?x?x8x2xf32> {
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[OUT_D0:.*]] = tensor.dim %[[OUT]], %[[C0]] : tensor<?x?x8x2xf32>
+// CHECK-DAG:       %[[OUT_D1:.*]] = tensor.dim %[[OUT]], %[[C1]] : tensor<?x?x8x2xf32>
+// CHECK:           %[[RES0:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[OUT_D0]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor<?x?x8x2xf32>) {
+// CHECK-DAG:         %[[OUT_I_SZ:.*]] = affine.min #[[MAP0]](%[[I]])[%[[OUT_D0]]]
+// CHECK:             %[[RES1:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[OUT_D1]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor<?x?x8x2xf32>) {
+// CHECK-DAG:           %[[OUT_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])[%[[OUT_D1]]]
+// CHECK-DAG:           %[[IN_I:.*]] = affine.apply #[[MAP2]](%[[I]])
+// CHECK-DAG:           %[[IN_I_SZ:.*]] = affine.min #[[MAP3]]
+// CHECK-DAG:           %[[IN_J:.*]] = affine.apply #[[MAP4]](%[[J]])
+// CHECK-DAG:           %[[IN_J_SZ:.*]] = affine.min #[[MAP5]]
+// CHECK:               %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_I]], %[[IN_J]]] [%[[IN_I_SZ]], %[[IN_J_SZ]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK:               %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[I]], %[[J]], 0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]], 8, 2] [1, 1, 1, 1] : tensor<?x?x8x2xf32> to tensor<?x?x8x2xf32>
+// CHECK:               %[[SUB_RES:.*]] = tensor.pack
+// CHECK-SAME:            %[[SUB_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2]
+// CHECK-SAME:            into %[[SUB_OUT]]
+// CHECK:               %[[INSERT:.*]] = tensor.insert_slice %[[SUB_RES]] into %[[ITER1]]
+// CHECK:               scf.yield %[[INSERT]] : tensor<?x?x8x2xf32>
+// CHECK:             }
+// CHECK:             scf.yield %[[RES1:.*]] : tensor<?x?x8x2xf32>
+// CHECK:           }
+// CHECK:           return %[[VAL_34:.*]] : tensor<?x?x8x2xf32>
+// CHECK:         }
+func.func @pad_and_pack_partially_dynamic(%input: tensor<?x?xf32>, %output: tensor<?x?x8x2xf32>, %pad: f32) -> tensor<?x?x8x2xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+  return %0 : tensor<?x?x8x2xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4]
+}
+
+// -----
+
+// CHECK-DAG:     #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
+// CHECK-DAG:     #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+// CHECK-DAG:     #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
+// CHECK-DAG:     #[[MAP3:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0, -(d1 * s0) + s1)>
+// CHECK:         func.func @pad_and_pack_fully_dynamic(
+// CHECK-SAME:      %[[IN:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:      %[[OUT:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME:      %[[PAD:.*]]: f32,
+// CHECK-SAME:      %[[TILE_0:.*]]: index,
+// CHECK-SAME:      %[[TILE_1:.*]]: index) -> tensor<?x?x?x?xf32> {
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:       %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[OUT_D0:.*]] = tensor.dim %[[OUT]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK-DAG:       %[[OUT_D1:.*]] = tensor.dim %[[OUT]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK:           %[[RES0:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[OUT_D0]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor<?x?x?x?xf32>) {
+// CHECK:             %[[OUT_I_SZ:.*]] = affine.min #[[MAP0]](%[[I]])[%[[OUT_D0]]]
+// CHECK:             %[[RES1:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[OUT_D1]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor<?x?x?x?xf32>) {
+// CHECK:               %[[OUT_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])[%[[OUT_D1]]]
+// CHECK:               %[[IN_D0:.*]] = tensor.dim %[[IN]], %[[C0]]
+// CHECK:               %[[IN_D1:.*]] = tensor.dim %[[IN]], %[[C1]]
+// CHECK:               %[[IN_I:.*]] = affine.apply #[[MAP2]](%[[I]])[%[[TILE_0]]]
+// CHECK:               %[[IN_I_SZ:.*]] = affine.min #[[MAP3]](%[[OUT_I_SZ]], %[[I]])[%[[TILE_0]], %[[IN_D0]]]
+// CHECK:               %[[IN_J:.*]] = affine.apply #[[MAP2]](%[[J]])[%[[TILE_1]]]
+// CHECK:               %[[IN_J_SZ:.*]] = affine.min #[[MAP3]](%[[OUT_J_SZ]], %[[J]])[%[[TILE_1]], %[[IN_D1]]]
+// CHECK:               %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_I]], %[[IN_J]]] [%[[IN_I_SZ]], %[[IN_J_SZ]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK:               %[[OUT_D2:.+]] = tensor.dim %[[OUT]], %[[C2]]
+// CHECK:               %[[OUT_D3:.+]] = tensor.dim %[[OUT]], %[[C3]]
+// CHECK:               %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[I]], %[[J]], 0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]], %[[OUT_D2]], %[[OUT_D3]]] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+// CHECK:               %[[PACK:.*]] = tensor.pack
+// CHECK-SAME:            %[[SUB_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [%[[TILE_0]], %[[TILE_1]]]
+// CHECK-SAME:            into %[[SUB_OUT]]
+// CHECK:               %[[INSERT:.*]] = tensor.insert_slice %[[PACK]] into %[[ITER1]]
+// CHECK:               scf.yield %[[INSERT]] : tensor<?x?x?x?xf32>
+// CHECK:             }
+// CHECK:             scf.yield %[[RES1:.*]] : tensor<?x?x?x?xf32>
+// CHECK:           }
+// CHECK:           return %[[RES0:.*]] : tensor<?x?x?x?xf32>
+// CHECK:         }
+func.func @pad_and_pack_fully_dynamic(%source: tensor<?x?xf32>, %dest: tensor<?x?x?x?xf32>, %pad: f32, %tile_n : index, %tile_m : index) -> tensor<?x?x?x?xf32> {
+  %0 = tensor.pack %source padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %dest : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4]
+}
diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 8b03a80a574ee..0023e455e6515 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5373,11 +5373,14 @@ cc_library(
     includes = ["include"],
     deps = [
         ":AffineDialect",
+        ":AffineUtils",
         ":ArithUtils",
+        ":DialectUtils",
         ":IR",
         ":LinalgDialect",
         ":SCFDialect",
         ":TensorDialect",
+        ":TensorUtils",
         ":TilingInterface",
         "//llvm:Support",
     ],
        
    
    
More information about the Mlir-commits
mailing list