[Mlir-commits] [mlir] 5443743 - [mlir][Linalg] Add a transform.structured.pack operation

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jan 17 05:25:21 PST 2023


Author: Nicolas Vasilache
Date: 2023-01-17T05:14:50-08:00
New Revision: 5443743ca1874acfe2d5654fedd4a0c0bed6777e

URL: https://github.com/llvm/llvm-project/commit/5443743ca1874acfe2d5654fedd4a0c0bed6777e
DIFF: https://github.com/llvm/llvm-project/commit/5443743ca1874acfe2d5654fedd4a0c0bed6777e.diff

LOG: [mlir][Linalg] Add a transform.structured.pack operation

This revision introduces a `transform.structured.pack` operation to
transform any Linalg operation to a higher-dimensional Linalg operation on
packed operands.

`tensor.pack` (resp. `tensor.unpack`) operations are inserted for the operands
(resp. results) that need to be packed (resp. unpacked) according to the
`packed_sizes` specification.
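The canonical-order rule for packed operand shapes can be illustrated with a small, self-contained C++ sketch that models only the shape bookkeeping. It assumes static sizes and operands indexed purely by iterator dimensions (plain `AffineDimExpr`s). `packedOperandShapes` is a hypothetical helper, not an MLIR API. With illustrative sizes M=8, N=9, K=12 and packed sizes [2, 3, 4], the matmul operands come out as A: 4x3x2x4, B: 3x3x3x4 (i.e. `KxNxnxk`) and C: 4x3x2x3.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper, not an MLIR API. Assumes static sizes and that every
// operand dimension is indexed by exactly one iterator (a plain AffineDimExpr).
//   operandDims[o][r] : iterator used by result r of operand o's indexing map
//   loopSizes[d]      : trip count of iterator d
//   packedSizes[d]    : packed size of iterator d (0 = do not pack)
std::vector<std::vector<int64_t>>
packedOperandShapes(const std::vector<std::vector<int>> &operandDims,
                    const std::vector<int64_t> &loopSizes,
                    const std::vector<int64_t> &packedSizes) {
  std::vector<std::vector<int64_t>> shapes;
  for (const std::vector<int> &dims : operandDims) {
    std::vector<int64_t> shape;
    // Outer dims keep the operand's own order; a packed dim becomes
    // ceildiv(size, tile), the tail of the last tile being padded.
    for (int d : dims) {
      int64_t size = loopSizes[d], tile = packedSizes[d];
      shape.push_back(tile == 0 ? size : (size + tile - 1) / tile);
    }
    // Inner tiles are appended in iterator order, not in operand order.
    for (std::size_t d = 0; d < packedSizes.size(); ++d) {
      if (packedSizes[d] == 0)
        continue;
      for (int used : dims)
        if (used == static_cast<int>(d))
          shape.push_back(packedSizes[d]);
    }
    shapes.push_back(shape);
  }
  return shapes;
}
```

A packed size of 0 leaves the dimension untouched: `packedOperandShapes({{0, 1}, {0}}, {3, 7}, {0, 4})` returns `{{3, 2, 4}, {3}}`, the shapes seen in the `reduction_2d_static` test.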

At the moment, the packing operation always pads with `getZeroAttr`, which will
need to be adjusted depending on the consumers.
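For the data movement itself, a hedged C++ model of `tensor.pack`/`tensor.unpack` restricted to packing dimension 1 of a row-major matrix (`inner_dims_pos = [1]`) shows where the zero padding ends up. `packDim1`, `unpackDim1` and `ceilDiv` are hypothetical names for illustration; the real ops are more general (arbitrary inner dims, outer permutations, dynamic sizes).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helpers, not MLIR APIs: they model tensor.pack / tensor.unpack
// of a row-major rows x cols matrix with inner_dims_pos = [1],
// inner_tiles = [tile], and no outer permutation.
int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Element (r, c) moves to (r, c / tile, c % tile) of a
// rows x ceilDiv(cols, tile) x tile buffer; slots past the last column keep
// `pad` (zero by default, like the getZeroAttr padding).
std::vector<float> packDim1(const std::vector<float> &src, int64_t rows,
                            int64_t cols, int64_t tile, float pad = 0.0f) {
  const int64_t outer = ceilDiv(cols, tile);
  std::vector<float> dst(rows * outer * tile, pad);
  for (int64_t r = 0; r < rows; ++r)
    for (int64_t c = 0; c < cols; ++c)
      dst[(r * outer + c / tile) * tile + c % tile] = src[r * cols + c];
  return dst;
}

// Inverse mapping: reads every (r, c) back and drops the padded slots.
std::vector<float> unpackDim1(const std::vector<float> &packed, int64_t rows,
                              int64_t cols, int64_t tile) {
  const int64_t outer = ceilDiv(cols, tile);
  std::vector<float> dst(rows * cols);
  for (int64_t r = 0; r < rows; ++r)
    for (int64_t c = 0; c < cols; ++c)
      dst[r * cols + c] = packed[(r * outer + c / tile) * tile + c % tile];
  return dst;
}
```

Packing a 3x7 matrix by 4 yields a 3x2x4 buffer in which the last slot of each row's second tile is padding, and unpacking recovers the original 3x7 data.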

Packing is limited to those dimensions that are indexed only by AffineDimExpr.
Packing more advanced indexings requires modular arithmetic that is outside the
scope of a `linalg.generic` at the moment.
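The `AffineDimExpr` restriction comes down to index arithmetic. The sketch below assumes non-negative indices and a packing factor of 4; `kTile`, `packedCoord` and the two check functions are illustrative names, and the code illustrates the reasoning rather than reproducing the Linalg implementation. Packing `A[I]` with `I = 4 * i + ii` yields the plain coordinates `(i, ii)`, whereas packing `A[I + J]` yields `(i + (ii + J) / 4, (ii + J) % 4)`, whose inner coordinate differs from `ii` unless `J` is a multiple of 4.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative packing factor.
constexpr int64_t kTile = 4;

// Coordinates of linear index `linear` in a 1-D array packed by kTile:
// A[I] lives at Ap[I / kTile][I % kTile] (non-negative indices only).
struct PackedCoord {
  int64_t outer;
  int64_t inner;
};

PackedCoord packedCoord(int64_t linear) {
  return {linear / kTile, linear % kTile};
}

// Access A[I] with I rewritten as kTile * i + ii, 0 <= ii < kTile:
// the packed coordinates are exactly (i, ii), two plain loop dimensions.
bool dimAccessIsPlain(int64_t numOuter) {
  for (int64_t i = 0; i < numOuter; ++i)
    for (int64_t ii = 0; ii < kTile; ++ii) {
      PackedCoord c = packedCoord(kTile * i + ii);
      if (c.outer != i || c.inner != ii)
        return false;
    }
  return true;
}

// Access A[I + J] after the same rewrite: the packed coordinates need
// division and modulo. Returns true if the inner coordinate differs from ii
// for at least one point, i.e. the accesses are not tile-aligned.
bool sumAccessIsMisaligned(int64_t numOuter, int64_t j) {
  bool innerAlwaysIi = true;
  for (int64_t i = 0; i < numOuter; ++i)
    for (int64_t ii = 0; ii < kTile; ++ii) {
      PackedCoord c = packedCoord(kTile * i + ii + j);
      assert(c.outer == i + (ii + j) / kTile);
      assert(c.inner == (ii + j) % kTile);
      innerAlwaysIi = innerAlwaysIi && (c.inner == ii);
    }
  return !innerAlwaysIi;
}
```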

Differential Revision: https://reviews.llvm.org/D141860

Added: 
    mlir/test/Dialect/Linalg/transform-op-pack.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 4cf8802f41f1f..2d0557a0b0ff6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -341,6 +341,90 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
   }];
 }
 
+def PackOp : Op<Transform_Dialect, "structured.pack", [
+                TransformOpInterface,
+                DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,]> {
+  let description = [{
+    Pack a LinalgOp by applying a data tiling transformation on the op and
+    packing the operands according to the `packed_sizes` specification.
+    
+    Iterator dimensions are tiled in their canonical order in the op spec.
+    Operands are packed according to the same canonical order of the op iterator
+    dimensions.
+
+    Specifying a packed size of 0 for an iterator removes it from consideration
+    for packing.
+
+    `tensor.pack` (resp. `tensor.unpack`) operations are inserted for the operands
+    (resp. results) that need to be packed (resp. unpacked) according to the
+    `packed_sizes` specification.
+
+    #### Example
+
+    Consider a `linalg.matmul` with indexing maps:
+    ```
+      //              M   N   K       M   K
+      // affine_map<(d0, d1, d2) -> (d0, d2)>
+      //                              K   N
+      // affine_map<(d0, d1, d2) -> (d2, d1)>
+      //                              M   N
+      // affine_map<(d0, d1, d2) -> (d0, d1)>
+      %0 = linalg.matmul  ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+                         outs(    %C: tensor<?x?xf32>)
+    ```
+
+    Specifying packed_sizes [2, 3, 4] results in tiling the iterator dimensions
+    M, N and K, in this order, in both the op and its operands.
+    ```
+      //              M   N   K   m   n   k       M   K   m   k
+      // affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+      //                                          K   N   n   k
+      // affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
+      //                                          M   N   m   n
+      // affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+      %0 = linalg.generic_representing_some_higher_d_matmul
+            ins(%A, %B: tensor<?x?x2x4xf32>, tensor<?x?x3x4xf32>)
+           outs(    %C: tensor<?x?x2x3xf32>)
+    ```
+    In particular, note that the second operand `B` has shape `KxNxnxk` (and not
+    `KxNxkxn` as one could expect by looking **only** at the operand).
+
+    Other layouts can be obtained from this canonical transformation by
+    composing the resulting operation with a (future)
+    `transform.structured.pack_transpose` op.
+    This composition separates concerns and composes better than adding
+    extra permutation attributes to this transform op.
+
+    #### Return modes
+
+    This operation applies to a single Linalg op; otherwise, it produces a
+    silenceable failure. An empty target handle is a no-op.
+    This operation produces a definite failure if the packing itself fails.
+
+    The returned handle points to the packed LinalgOp.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                   Variadic<PDL_Operation>:$packed_sizes,
+                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_packed_sizes);
+  let results = (outs TransformHandleTypeInterface:$packed_op);
+  let assemblyFormat = [{
+    $target 
+    `packed_sizes` `=` custom<DynamicIndexList>($packed_sizes,
+                                                $static_packed_sizes)
+    attr-dict
+    `:` functional-type($target, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+      transform::TransformResults &transformResults,
+      transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedPackedSizes();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // PadOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7d2f6da935ae8..3a33e0483e8c6 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/ScopeExit.h"
@@ -35,6 +36,8 @@ using namespace mlir::linalg;
 using namespace mlir::transform;
 
 #define DEBUG_TYPE "linalg-transforms"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
 
 /// Attempts to apply the pattern specified as template argument to the given
 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
@@ -60,6 +63,67 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
   return cast<LinalgOp>(result->getOperation());
 }
 
+/// For each entry of `ofrs`, which must be either an IntegerAttr or a transform
+/// dialect handle mapped to exactly one op with one index result, append that
+/// attribute or result value to `result`.
+static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations(
+    transform::TransformState &state, TransformOpInterface transformOp,
+    SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
+  for (OpFoldResult ofr : ofrs) {
+    if (ofr.is<Attribute>()) {
+      if (!ofr.get<Attribute>().isa<IntegerAttr>())
+        return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+      result.push_back(ofr);
+      continue;
+    }
+    ArrayRef<Operation *> payloadOps = state.getPayloadOps(ofr.get<Value>());
+    if (payloadOps.size() != 1) {
+      DiagnosedSilenceableFailure diag =
+          transformOp.emitSilenceableError()
+          << "handle must be mapped to exactly one payload op";
+      diag.attachNote(ofr.get<Value>().getLoc())
+          << "mapped to " << payloadOps.size() << " payload ops";
+      return diag;
+    }
+
+    Operation *op = payloadOps[0];
+    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+      DiagnosedSilenceableFailure diag =
+          transformOp.emitSilenceableError()
+          << "payload op must have exactly 1 index result";
+      diag.attachNote(op->getLoc())
+          << "has " << op->getNumResults() << " results";
+      return diag;
+    }
+    result.push_back(op->getResult(0));
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Given a transform dialect handle `packedHandle`, append to `result` the
+/// single OpResult of every payload op it maps to. Each payload op must have
+/// exactly one result, of index type.
+static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations(
+    transform::TransformState &state, TransformOpInterface transformOp,
+    SmallVector<OpFoldResult> &result, Value packedHandle) {
+  ArrayRef<Operation *> payloadOps = state.getPayloadOps(packedHandle);
+  for (Operation *op : payloadOps) {
+    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+      DiagnosedSilenceableFailure diag =
+          transformOp.emitSilenceableError()
+          << "payload op must have exactly 1 index result";
+      diag.attachNote(op->getLoc())
+          << "has " << op->getNumResults() << " results";
+      return diag;
+    }
+    result.push_back(op->getResult(0));
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // DecomposeOp
 //===----------------------------------------------------------------------===//
@@ -743,6 +807,334 @@ void transform::MultiTileSizesOp::getEffects(
   modifiesPayload(effects);
 }
 
+//===---------------------------------------------------------------------===//
+// PackOp
+//===---------------------------------------------------------------------===//
+
+SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
+  Builder b(getContext());
+  return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
+}
+
+/// Return true if at most one result of `map` is a function of
+/// AffineDimExpr(dim).
+static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
+  bool found = false;
+  for (AffineExpr e : map.getResults()) {
+    if (!e.isFunctionOfDim(dim))
+      continue;
+    if (found)
+      return false;
+    found = true;
+  }
+  return true;
+}
+
+/// Return the index of the first result of `map` that is a function of
+/// AffineDimExpr(dim), std::nullopt otherwise.
+static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
+                                                            int64_t dim) {
+  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+    AffineExpr expr = map.getResult(i);
+    if (!expr.isFunctionOfDim(dim))
+      continue;
+    return i;
+  }
+  return std::nullopt;
+}
+
+/// Perform one step of packing of a LinalgOp's metadata along `dim` into the
+/// `newDim` at `iteratorTypes.size()` by:
+///   1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
+///   2. Appending a `newDim` to the domain of every indexing map.
+///   3. For each operand (i.e. for each map in `indexingMaps`), perform packing
+///      by potentially adding a `newDim` result to `map`.
+/// The preserved invariant is that `iteratorTypes.size()` is always equal to
+/// `map.getNumDims()` for every map in `indexingMaps`.
+///
+/// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update.
+/// Return a vector that records the optional packing for each operand.
+/// Return failure if the packed indexing cannot be represented with a LinalgOp.
+///
+/// Further details:
+/// ================
+/// The current implementation of packing (i.e. data tiling) consists of
+/// rewriting a linearized strip-mined form into a higher-dimensional access.
+/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
+/// `I` into `4 * i + ii`, where `0 <= ii < 4`.
+/// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
+///
+/// This rewrite into higher dimensional access is not possible for general
+/// AffineExpr in Linalg atm, it is restricted to an AffineDimExpr:
+/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
+/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
+/// The rewrite of the access would be a form not representable in Linalg:
+///   `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
+/// Note however that as `J` and `ii` iterate, the accesses do not have a
+/// particular alignment, so packing does not achieve alignment in this case
+///
+/// In the future, we may want to consider a mixed-form that allows some
+/// alignment in the presence of multiple accesses:
+///   `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
+/// And would rewrite accesses as:
+///   `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
+static FailureOr<SmallVector<std::optional<int64_t>>>
+packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
+                       SmallVectorImpl<utils::IteratorType> &iteratorTypes,
+                       int64_t dim) {
+  int64_t newDim = iteratorTypes.size();
+  iteratorTypes.push_back(iteratorTypes[dim]);
+
+  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
+      indexingMaps.size(), std::nullopt);
+  SmallVector<AffineMap> newMaps;
+  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
+       ++operandIdx) {
+    AffineMap map = indexingMaps[operandIdx];
+
+    // Add the `newDim` to map whatever the case.
+    assert(map.getNumDims() == newDim && "num dims invariant violation");
+    map = map.shiftDims(1, newDim);
+
+    // Get the at-most-1 index of the result that is a function of `dim`.
+    // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
+    // logically chunks dimension `dim` into `K * dim + newDim`, where the
+    // packing factor `K` is specified separately.
+    assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
+           "num results invariant violation");
+    auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
+    if (!maybeOperandDimensionToPack.has_value()) {
+      newMaps.push_back(map);
+      continue;
+    }
+
+    // We can only pack AffineDimExpr atm.
+    if (!map.getResult(maybeOperandDimensionToPack.value())
+             .isa<AffineDimExpr>())
+      return failure();
+
+    // Add `newDim` to the results of the map.
+    map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
+                           map.getNumResults());
+    newMaps.push_back(map);
+
+    // Record that `operandIdx` is packed.
+    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
+  }
+  indexingMaps = newMaps;
+
+  return packedDimPerIndexingMap;
+}
+
+namespace {
+
+/// Helper struct to encode packing along one dimension of a LinalgOp.
+struct PackedOperandsDim {
+  OpFoldResult packedSize;
+  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
+};
+
+/// Helper struct to encode packing along all dimensions of a LinalgOp.
+struct PackedOperandsDimList {
+  void push_back(PackedOperandsDim &&packedOperandsDims) {
+    spec.emplace_back(packedOperandsDims);
+  }
+  /// Return all the dims that have been packed for operand @ `operandPos`.
+  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
+  /// Return all the pack sizes by which an operand @ `operandPos` is packed.
+  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);
+
+private:
+  SmallVector<PackedOperandsDim> spec;
+};
+
+} // namespace
+
+SmallVector<int64_t>
+PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
+  SmallVector<int64_t> res;
+  for (int64_t i = 0, e = spec.size(); i < e; ++i) {
+    if (!spec[i].packedDimForEachOperand[operandPos].has_value())
+      continue;
+    res.push_back(spec[i].packedDimForEachOperand[operandPos].value());
+  }
+  return res;
+}
+
+SmallVector<OpFoldResult>
+PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
+  SmallVector<OpFoldResult> res;
+  for (int64_t i = 0, e = spec.size(); i < e; ++i) {
+    if (!spec[i].packedDimForEachOperand[operandPos].has_value())
+      continue;
+    res.push_back(spec[i].packedSize);
+  }
+  return res;
+}
+
+/// Implement packing of a single LinalgOp by performing packing by
+/// `packedSizeHandles`. There must be one packedSizeHandles entry per
+/// `linalgOp` iterator. Return the packed Linalg op on success, failure
+/// otherwise.
+static FailureOr<linalg::LinalgOp>
+packOneLinalgOp(RewriterBase &rewriter, transform::TransformState &state,
+                TransformOpInterface transformOp, linalg::LinalgOp linalgOp,
+                ArrayRef<OpFoldResult> packedSizeHandles) {
+  assert(packedSizeHandles.size() == linalgOp.getNumLoops() &&
+         "incorrect number of pack sizes");
+
+  Location loc = linalgOp->getLoc();
+  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+  SmallVector<utils::IteratorType> iteratorTypes =
+      linalgOp.getIteratorTypesArray();
+  LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n";
+             llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
+             llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
+             DBGSNL(););
+
+  // Unpack handles to constants or actual SSA index values.
+  SmallVector<OpFoldResult> packedSizes;
+  DiagnosedSilenceableFailure status = unpackSingleIndexResultPDLOperations(
+      state, transformOp, packedSizes, packedSizeHandles);
+  // Do not proceed with partially unpacked pack sizes.
+  if (!status.succeeded())
+    return failure();
+
+  // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
+  PackedOperandsDimList listOfPackedOperandsDim;
+  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
+    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
+    // Skip pack sizes explicitly set to 0.
+    if (maybeConstant.has_value() && maybeConstant.value() == 0)
+      continue;
+
+    PackedOperandsDim packedOperandsDims;
+    packedOperandsDims.packedSize = packedSizes[i];
+    FailureOr<SmallVector<std::optional<int64_t>>>
+        maybePackedDimForEachOperand =
+            packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
+    if (failed(maybePackedDimForEachOperand))
+      return failure();
+    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
+    listOfPackedOperandsDim.push_back(std::move(packedOperandsDims));
+
+    LLVM_DEBUG(
+        DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
+               << "\n";
+        llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
+        llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();
+        llvm::interleaveComma(*maybePackedDimForEachOperand,
+                              DBGS() << "packedDimForEachOperand: ");
+        DBGSNL(););
+  }
+
+  // Step 2. Propagate packing to all LinalgOp operands.
+  SmallVector<Value> inputsAndInits, results;
+  for (auto operandsList :
+       {linalgOp.getDpsInputOperands(), linalgOp.getDpsInitOperands()}) {
+    for (OpOperand *opOperandPtr : operandsList) {
+      int64_t pos = opOperandPtr->getOperandNumber();
+      Value operand = opOperandPtr->get();
+      SmallVector<int64_t> innerPos =
+          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
+      SmallVector<OpFoldResult> innerPackSizes =
+          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
+      LLVM_DEBUG(
+          DBGS() << "operand: " << operand << "\n";
+          llvm::interleaveComma(innerPos, DBGS() << "innerPos: "); DBGSNL();
+          llvm::interleaveComma(innerPackSizes, DBGS() << "innerPackSizes: ");
+          DBGSNL(););
+      if (innerPackSizes.empty()) {
+        inputsAndInits.push_back(operand);
+        continue;
+      }
+      Value dest = tensor::PackOp::createDestinationTensor(
+          rewriter, loc, operand, innerPackSizes, innerPos,
+          /*outerDimsPerm=*/{});
+      // TODO: value of the padding attribute should be determined by consumers.
+      Attribute zeroAttr =
+          rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
+      Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
+      inputsAndInits.push_back(rewriter.create<tensor::PackOp>(
+          loc, operand, dest, innerPos, innerPackSizes, zero));
+    }
+  }
+
+  // Step 3. Build the packed op, use the type of `inits` as result types.
+  ValueRange inputs =
+      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
+  ValueRange inits =
+      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
+  auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
+      linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
+      iteratorTypes);
+  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
+
+  // Step 4. Propagate packing to all the op results.
+  for (OpResult result : packedLinalgOp->getResults()) {
+    int64_t resultNum = result.getResultNumber();
+    tensor::PackOp maybePackedInit =
+        inits[resultNum].getDefiningOp<tensor::PackOp>();
+    if (!maybePackedInit) {
+      results.push_back(result);
+      continue;
+    }
+    // Build the symmetrical UnPackOp to the existing PackOp.
+    results.push_back(rewriter.create<tensor::UnPackOp>(
+        packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
+        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
+  }
+
+  // Step 5. Replace `linalgOp`.
+  rewriter.replaceOp(linalgOp, results);
+
+  // Return packedLinalgOp.
+  return cast<linalg::LinalgOp>(packedLinalgOp.getOperation());
+}
+
+DiagnosedSilenceableFailure
+transform::PackOp::apply(transform::TransformResults &transformResults,
+                         transform::TransformState &state) {
+  ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
+  // If nothing to pack, propagate success.
+  if (targetOps.empty()) {
+    transformResults.set(getPackedOp().cast<OpResult>(), {});
+    return DiagnosedSilenceableFailure::success();
+  }
+  // Fail on multi-op handles.
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(targetOps.front());
+  if (targetOps.size() != 1 || !linalgOp) {
+    // TODO: remove this unnecessary set to empty once crashes are fixed.
+    transformResults.set(getPackedOp().cast<OpResult>(), {});
+    return emitSilenceableError()
+           << "requires target to map to exactly 1 LinalgOp (got "
+           << targetOps.size() << ")";
+  }
+  // Fail on mismatched number of pack sizes.
+  if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
+    // TODO: remove this unnecessary set to empty once crashes are fixed.
+    transformResults.set(getPackedOp().cast<OpResult>(), {});
+    return emitSilenceableError()
+           << "requires number of packed sizes match the number of loops ("
+           << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
+           << ")";
+  }
+
+  IRRewriter rewriter(linalgOp->getContext());
+  rewriter.setInsertionPoint(linalgOp);
+  FailureOr<LinalgOp> maybeResult =
+      packOneLinalgOp(rewriter, state, *this, linalgOp, getMixedPackedSizes());
+  if (failed(maybeResult))
+    return emitDefiniteFailure("data tiling failed");
+
+  transformResults.set(getPackedOp().cast<OpResult>(),
+                       maybeResult->getOperation());
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::PackOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::consumesHandle(getTarget(), effects);
+  transform::onlyReadsHandle(getPackedSizes(), effects);
+  transform::producesHandle(getPackedOp(), effects);
+}
+
 //===---------------------------------------------------------------------===//
 // PadOp
 //===---------------------------------------------------------------------===//
@@ -1608,68 +2000,6 @@ void transform::TileToForeachThreadOp::build(
         /*mapping=*/mapping);
 }
 
-/// Assuming that `ofr` is an index attr or a transform dialect handle mapped
-/// to exactly one op with one index result, return that value.
-static DiagnosedSilenceableFailure unpackPDLOperations(
-    transform::TransformState &state, TransformOpInterface transformOp,
-    SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
-  for (OpFoldResult ofr : ofrs) {
-    if (ofr.is<Attribute>()) {
-      if (!ofr.get<Attribute>().isa<IntegerAttr>())
-        return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
-      result.push_back(ofr);
-      continue;
-    }
-    ArrayRef<Operation *> payloadOps = state.getPayloadOps(ofr.get<Value>());
-    if (payloadOps.size() != 1) {
-      DiagnosedSilenceableFailure diag =
-          transformOp.emitSilenceableError()
-          << "handle must be mapped to exactly one payload op";
-      diag.attachNote(ofr.get<Value>().getLoc())
-          << "mapped to " << payloadOps.size() << " payload ops";
-      return diag;
-    }
-
-    Operation *op = payloadOps[0];
-    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
-      DiagnosedSilenceableFailure diag =
-          transformOp.emitSilenceableError()
-          << "payload op must have exactly 1 index result";
-      diag.attachNote(op->getLoc())
-          << "has " << op->getNumResults() << " results";
-      return diag;
-    }
-    result.push_back(op->getResult(0));
-  }
-
-  return DiagnosedSilenceableFailure::success();
-}
-
-// Given a list of OpFoldResults that are either index attrs or op
-// handles, return a list of OpFoldResults where all op handles are
-// replaced with the first (and only) OpResult of that payload op. (There
-// must be exactly one mapped payload op and it must have exactly one
-// index result.)
-static DiagnosedSilenceableFailure
-unpackPDLOperations(transform::TransformState &state,
-                    TransformOpInterface transformOp,
-                    SmallVector<OpFoldResult> &result, Value packedHandle) {
-  ArrayRef<Operation *> payloadOps = state.getPayloadOps(packedHandle);
-  for (Operation *op : payloadOps) {
-    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
-      DiagnosedSilenceableFailure diag =
-          transformOp.emitSilenceableError()
-          << "payload op must have exactly 1 index result";
-      diag.attachNote(op->getLoc())
-          << "has " << op->getNumResults() << " results";
-      return diag;
-    }
-    result.push_back(op->getResult(0));
-  }
-
-  return DiagnosedSilenceableFailure::success();
-}
-
 DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
     RewriterBase &rewriter, transform::TransformState &state,
     TransformOpInterface transformOp, ArrayRef<Operation *> targets,
@@ -1724,18 +2054,18 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
   SmallVector<OpFoldResult> mixedNumThreads;
   DiagnosedSilenceableFailure status =
       getPackedNumThreads()
-          ? unpackPDLOperations(state, transformOp, mixedNumThreads,
-                                getPackedNumThreads())
-          : unpackPDLOperations(state, transformOp, mixedNumThreads,
-                                getMixedNumThreads());
+          ? unpackSingleIndexResultPDLOperations(
+                state, transformOp, mixedNumThreads, getPackedNumThreads())
+          : unpackSingleIndexResultPDLOperations(
+                state, transformOp, mixedNumThreads, getMixedNumThreads());
   if (!status.succeeded())
     return status;
   SmallVector<OpFoldResult> mixedTileSizes;
   status = getPackedTileSizes()
-               ? unpackPDLOperations(state, transformOp, mixedTileSizes,
-                                     getPackedTileSizes())
-               : unpackPDLOperations(state, transformOp, mixedTileSizes,
-                                     getMixedTileSizes());
+               ? unpackSingleIndexResultPDLOperations(
+                     state, transformOp, mixedTileSizes, getPackedTileSizes())
+               : unpackSingleIndexResultPDLOperations(
+                     state, transformOp, mixedTileSizes, getMixedTileSizes());
   if (!status.succeeded())
     return status;
 

diff  --git a/mlir/test/Dialect/Linalg/transform-op-pack.mlir b/mlir/test/Dialect/Linalg/transform-op-pack.mlir
new file mode 100644
index 0000000000000..d1304bb2be483
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-pack.mlir
@@ -0,0 +1,406 @@
+// RUN: mlir-opt -test-transform-dialect-interpreter -split-input-file -verify-diagnostics -allow-unregistered-dialect %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#reduction_2d_trait = {
+  indexing_maps = [#map, #map1],
+  iterator_types = ["parallel", "reduction"]
+}
+
+//    CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//    CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2) -> (d0)>
+
+//  CHECK-LABEL: @reduction_2d_static
+//   CHECK-SAME:   %[[T0:.+]]: tensor<3x7xf16>,
+//   CHECK-SAME:   %[[T1:.+]]: tensor<3xf16>
+func.func @reduction_2d_static(%t0: tensor<3x7xf16>, %t1: tensor<3xf16>) -> tensor<3xf16> {
+  //      CHECK:  %[[EMPTY:.*]] = tensor.empty() : tensor<3x2x4xf16>
+  //      CHECK: %[[PACKED:.*]] = tensor.pack %[[T0]] padding_value(%{{.*}} : f16) 
+  // CHECK-SAME:   inner_dims_pos = [1] inner_tiles = [4] into %[[EMPTY]] : tensor<3x7xf16> -> tensor<3x2x4xf16>
+  //  CHECK-NOT: tensor.pack
+  //      CHECK: linalg.generic 
+  // CHECK-SAME:   indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]]]
+  // CHECK-SAME:   iterator_types = ["parallel", "reduction", "reduction"]
+  // CHECK-SAME:   ins(%{{.*}} : tensor<3x2x4xf16>)
+  // CHECK-SAME:  outs(%{{.*}} : tensor<3xf16>)
+  %2 = linalg.generic #reduction_2d_trait ins(%t0 : tensor<3x7xf16>) outs(%t1 : tensor<3xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %3 = arith.addf %in, %out : f16
+    linalg.yield %3 : f16
+  } -> tensor<3xf16>
+
+  //  CHECK-NOT: tensor.unpack
+  return %2 : tensor<3xf16>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  transform.structured.pack %0 packed_sizes = [0, 4]
+      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+#col_reduction_2d_trait = {
+  indexing_maps = [#map, #map1],
+  iterator_types = ["reduction", "parallel"]
+}
+
+//    CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//    CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2) -> (d1)>
+
+//  CHECK-LABEL: @col_reduction_2d_static
+//   CHECK-SAME:   %[[T0:.+]]: tensor<7x3xf16>,
+//   CHECK-SAME:   %[[T1:.+]]: tensor<3xf16>
+func.func @col_reduction_2d_static(%t0: tensor<7x3xf16>, %t1: tensor<3xf16>) -> tensor<3xf16> {
+  //      CHECK:  %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x4xf16>
+  //      CHECK: %[[PACKED:.*]] = tensor.pack %[[T0]] padding_value(%{{.*}} : f16) 
+  // CHECK-SAME:   inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<7x3xf16> -> tensor<2x3x4xf16>
+  //  CHECK-NOT: tensor.pack
+  //      CHECK: linalg.generic 
+  // CHECK-SAME:   indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]]]
+  // CHECK-SAME:   iterator_types = ["reduction", "parallel", "reduction"]
+  // CHECK-SAME:   ins(%{{.*}} : tensor<2x3x4xf16>)
+  // CHECK-SAME:  outs(%{{.*}} : tensor<3xf16>)
+  %2 = linalg.generic #col_reduction_2d_trait ins(%t0 : tensor<7x3xf16>) outs(%t1 : tensor<3xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %3 = arith.addf %in, %out : f16
+    linalg.yield %3 : f16
+  } -> tensor<3xf16>
+
+  //  CHECK-NOT: tensor.unpack
+  return %2 : tensor<3xf16>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  transform.structured.pack %0 packed_sizes = [4, 0]
+      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#reduction_2d_trait = {
+  indexing_maps = [#map, #map1],
+  iterator_types = ["parallel", "reduction"]
+}
+
+//    CHECK-DAG:     #[[$DIV4:.*]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+//    CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//    CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2) -> (d0)>
+
+//  CHECK-LABEL: @reduction_2d_dynamic
+//   CHECK-SAME:   %[[T0:.+]]: tensor<?x?xf16>,
+//   CHECK-SAME:   %[[T1:.+]]: tensor<?xf16>
+func.func @reduction_2d_dynamic(%t0: tensor<?x?xf16>, %t1: tensor<?xf16>) -> tensor<?xf16> {
+  //  CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+  //  CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+  //  CHECK-DAG:     %[[D0:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?xf16>
+  //  CHECK-DAG:     %[[D1:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?xf16>
+  //      CHECK:   %[[D1B4:.*]] = affine.apply #[[$DIV4]]()[%[[D1]]]
+  //      CHECK:  %[[EMPTY:.*]] = tensor.empty(%[[D0]], %[[D1B4]]) : tensor<?x?x4xf16>
+  //      CHECK: %[[PACKED:.*]] = tensor.pack %[[T0]] padding_value(%{{.*}} : f16) 
+  // CHECK-SAME:   inner_dims_pos = [1] inner_tiles = [4] into %[[EMPTY]] : tensor<?x?xf16> -> tensor<?x?x4xf16>
+  //  CHECK-NOT: tensor.pack
+  //      CHECK: linalg.generic 
+  // CHECK-SAME:   indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]]]
+  // CHECK-SAME:   iterator_types = ["parallel", "reduction", "reduction"]
+  // CHECK-SAME:   ins(%{{.*}} : tensor<?x?x4xf16>)
+  // CHECK-SAME:  outs(%{{.*}} : tensor<?xf16>)
+  %2 = linalg.generic #reduction_2d_trait ins(%t0 : tensor<?x?xf16>) outs(%t1 : tensor<?xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %3 = arith.addf %in, %out : f16
+    linalg.yield %3 : f16
+  } -> tensor<?xf16>
+
+  //  CHECK-NOT: tensor.unpack
+  return %2 : tensor<?xf16>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  transform.structured.pack %0 packed_sizes = [0, 4]
+      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#reduction_2d_trait = {
+  indexing_maps = [#map, #map1],
+  iterator_types = ["parallel", "reduction"]
+}
+
+//    CHECK-DAG:     #[[$DIV3:.*]] = affine_map<()[s0] -> (s0 ceildiv 3)>
+//    CHECK-DAG:     #[[$DIV4:.*]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+//    CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//    CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+
+//  CHECK-LABEL: @reduction_2d_dynamic
+//   CHECK-SAME:   %[[T0:.+]]: tensor<?x?xf16>,
+//   CHECK-SAME:   %[[T1:.+]]: tensor<?xf16>
+func.func @reduction_2d_dynamic(%t0: tensor<?x?xf16>, %t1: tensor<?xf16>) -> tensor<?xf16> {
+  //      CHECK: %[[PACKED_0:.*]] = tensor.pack %[[T0]] padding_value(%{{.*}} : f16) 
+  // CHECK-SAME:   inner_dims_pos = [0, 1] inner_tiles = [3, 4] into %{{.*}} : tensor<?x?xf16> -> tensor<?x?x3x4xf16>
+  //      CHECK: %[[PACKED_1:.*]] = tensor.pack %[[T1]] padding_value(%{{.*}} : f16) 
+  // CHECK-SAME:   inner_dims_pos = [0] inner_tiles = [3] into %{{.*}} : tensor<?xf16> -> tensor<?x3xf16>
+  //  CHECK-NOT: tensor.pack
+  //      CHECK: linalg.generic 
+  // CHECK-SAME:   indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]]]
+  // CHECK-SAME:   iterator_types = ["parallel", "reduction", "parallel", "reduction"]
+  // CHECK-SAME:   ins(%{{.*}} : tensor<?x?x3x4xf16>)
+  // CHECK-SAME:  outs(%{{.*}} : tensor<?x3xf16>)
+  %2 = linalg.generic #reduction_2d_trait ins(%t0 : tensor<?x?xf16>) outs(%t1 : tensor<?xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %3 = arith.addf %in, %out : f16
+    linalg.yield %3 : f16
+  } -> tensor<?xf16>
+
+  //      CHECK: tensor.unpack %{{.*}} inner_dims_pos = [0] inner_tiles = [3] into %{{.*}} : tensor<?x3xf16> -> tensor<?xf16>
+  return %2 : tensor<?xf16>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  transform.structured.pack %0 packed_sizes = [3, 4]
+      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+//                                                M   N   K   m   n   k       M   K   m   k
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+//                                                                            K   N   n   k
+// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
+//                                                                            M   N   m   n
+// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: @matmul
+//  CHECK-SAME:   %[[A:[0-9a-zA-Z]+]]: tensor<?x?xf32>,
+//  CHECK-SAME:   %[[B:[0-9a-zA-Z]+]]: tensor<?x?xf32>,
+//  CHECK-SAME:   %[[C:[0-9a-zA-Z]+]]: tensor<?x?xf32>
+func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+
+  //      CHECK: %[[PACK_A:.*]] = tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 4]
+  // CHECK-SAME:   : tensor<?x?xf32> -> tensor<?x?x2x4xf32>
+  //      CHECK: %[[PACK_B:.*]] = tensor.pack %{{.*}} inner_dims_pos = [1, 0] inner_tiles = [3, 4]
+  // CHECK-SAME:   : tensor<?x?xf32> -> tensor<?x?x3x4xf32>
+  //      CHECK: %[[PACK_C:.*]] = tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 3]
+  // CHECK-SAME:   : tensor<?x?xf32> -> tensor<?x?x2x3xf32>
+
+  //      CHECK: linalg.generic {indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]], #[[$PACKED_MAP_2]]]
+  // CHECK-SAME:     iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} 
+  // CHECK-SAME:  ins(%{{.*}} : tensor<?x?x2x4xf32>, tensor<?x?x3x4xf32>)
+  // CHECK-SAME: outs(%{{.*}} : tensor<?x?x2x3xf32>)
+  %0 = linalg.matmul  ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%C: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+
+  //      CHECK: tensor.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 3]
+  // CHECK-SAME:   : tensor<?x?x2x3xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    //                                            M  N  K
+    %1 = transform.structured.pack %0 packed_sizes = [2, 3, 4]
+      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+//                                                N   F   H   W   C  KH  KW   f   c
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d4, d2 + d5, d3 + d6, d8)>
+// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d4, d5, d6, d7, d8)>
+// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d7)>
+
+// CHECK-LABEL: @conv_2d_nchw_fchw
+//  CHECK-SAME:   %[[INPUT:.+]]: tensor<14x512x28x28xf32>,
+//  CHECK-SAME:   %[[FILTER:.+]]: tensor<1024x512x1x1xf32>
+//  CHECK-SAME:   %[[INIT:.+]]: tensor<14x1024x28x28xf32>
+func.func @conv_2d_nchw_fchw(%i: tensor<14x512x28x28xf32>, %f: tensor<1024x512x1x1xf32>,
+                             %o: tensor<14x1024x28x28xf32>) -> tensor<14x1024x28x28xf32> {
+
+  //      CHECK: %[[PACK_INPUT:.*]] = tensor.pack %{{.*}} inner_dims_pos = [1] inner_tiles = [8]
+  // CHECK-SAME:   : tensor<14x512x28x28xf32> -> tensor<14x64x28x28x8xf32>
+  //      CHECK: %[[PACK_FILTER:.*]] = tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [4, 8]
+  // CHECK-SAME:   : tensor<1024x512x1x1xf32> -> tensor<256x64x1x1x4x8xf32>
+  //      CHECK: %[[PACK_OUTPUT:.*]] = tensor.pack %{{.*}} inner_dims_pos = [1] inner_tiles = [4]
+  // CHECK-SAME:   : tensor<14x1024x28x28xf32> -> tensor<14x256x28x28x4xf32>
+  //      CHECK: linalg.generic {indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]], #[[$PACKED_MAP_2]]]
+  // CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "parallel", "reduction"]} 
+  // CHECK-SAME:  ins(%{{.*}} : tensor<14x64x28x28x8xf32>, tensor<256x64x1x1x4x8xf32>)
+  // CHECK-SAME: outs(%{{.*}} : tensor<14x256x28x28x4xf32>)
+  %0 = linalg.conv_2d_nchw_fchw ins(%i, %f: tensor<14x512x28x28xf32>, tensor<1024x512x1x1xf32>)
+                                outs(%o: tensor<14x1024x28x28xf32>) -> tensor<14x1024x28x28xf32>
+
+  //      CHECK: tensor.unpack %{{.*}} inner_dims_pos = [1] inner_tiles = [4]
+  // CHECK-SAME:   : tensor<14x256x28x28x4xf32> -> tensor<14x1024x28x28xf32>
+  return %0: tensor<14x1024x28x28xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match interface{LinalgOp} in %arg1
+  //                                            N  F  H  W  C KH KW
+  %1 = transform.structured.pack %0 packed_sizes = [0, 4, 0, 0, 8, 0, 0]
+      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+//                                                N   H   W   F  KH  KW   C   f   c
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d4, d2 + d5, d6, d8)>
+// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d4, d5, d6, d3, d7, d8)>
+// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d7)>
+
+// CHECK-LABEL: @conv_2d_nhwc_hwcf
+//  CHECK-SAME:   %[[INPUT:.+]]: tensor<?x1x?x?xf32>,
+//  CHECK-SAME:   %[[FILTER:.+]]: tensor<1x?x?x?xf32>
+//  CHECK-SAME:   %[[INIT:.+]]: tensor<?x1x?x?xf32>
+func.func @conv_2d_nhwc_hwcf(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?x?x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+  
+  //      CHECK: %[[PACK_INPUT:.*]] = tensor.pack %{{.*}} inner_dims_pos = [3] inner_tiles = [6]
+  // CHECK-SAME:   : tensor<?x1x?x?xf32> -> tensor<?x1x?x?x6xf32>
+  //      CHECK: %[[PACK_FILTER:.*]] = tensor.pack %{{.*}} inner_dims_pos = [3, 2] inner_tiles = [4, 6]
+  // CHECK-SAME:   : tensor<1x?x?x?xf32> -> tensor<1x?x?x?x4x6xf32>
+  //      CHECK: %[[PACK_OUTPUT:.*]] = tensor.pack %{{.*}} inner_dims_pos = [3] inner_tiles = [4]
+  // CHECK-SAME:   : tensor<?x1x?x?xf32> -> tensor<?x1x?x?x4xf32>
+
+  //      CHECK: linalg.generic {indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]], #[[$PACKED_MAP_2]]]
+  // CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "parallel", "reduction"]} 
+  // CHECK-SAME:  ins(%{{.*}} : tensor<?x1x?x?x6xf32>, tensor<1x?x?x?x4x6xf32>)
+  // CHECK-SAME: outs(%{{.*}} : tensor<?x1x?x?x4xf32>)
+  %0 = linalg.conv_2d_nhwc_hwcf
+     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>)
+    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+  
+  //      CHECK: tensor.unpack %{{.*}} inner_dims_pos = [3] inner_tiles = [4]
+  // CHECK-SAME:   : tensor<?x1x?x?x4xf32> -> tensor<?x1x?x?xf32>
+  return %0 : tensor<?x1x?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match interface{LinalgOp} in %arg1
+  //                                            N  H  W  F KH KW  C
+  %1 = transform.structured.pack %0 packed_sizes = [0, 0, 0, 4, 0, 0, 6]
+      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+// CHECK-DAG: affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+//                                                M   N   K   n   k       M   K   k
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>
+//                                                                        K   N   n   k
+// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d1, d3, d4)>
+//                                                                        M   N   n
+// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+
+// CHECK-LABEL: @matmul_dynamic_pack_size
+//  CHECK-SAME:   %[[A:[0-9a-zA-Z]+]]: tensor<?x?xf32>,
+//  CHECK-SAME:   %[[B:[0-9a-zA-Z]+]]: tensor<?x?xf32>,
+//  CHECK-SAME:   %[[C:[0-9a-zA-Z]+]]: tensor<?x?xf32>
+func.func @matmul_dynamic_pack_size(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+  //      CHECK: %[[TS:.*]] = "some_tile_size"() : () -> index
+  %sz = "some_tile_size"() : () -> (index)
+
+  //      CHECK: %[[PACK_A:.*]] = tensor.pack %[[A]] {{.*}} inner_dims_pos = [1] inner_tiles = [%[[TS]]]
+  // CHECK-SAME:   : tensor<?x?xf32> -> tensor<?x?x?xf32>
+  //      CHECK: %[[PACK_B:.*]] = tensor.pack %[[B]] {{.*}} inner_dims_pos = [1, 0] inner_tiles = [%[[TS]], %[[TS]]]
+  // CHECK-SAME:   : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+  //      CHECK: %[[PACK_C:.*]] = tensor.pack %[[C]] {{.*}} inner_dims_pos = [1] inner_tiles = [%[[TS]]]
+  // CHECK-SAME:   : tensor<?x?xf32> -> tensor<?x?x?xf32>
+  //      CHECK: linalg.generic {indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]], #[[$PACKED_MAP_2]]]
+  // CHECK-SAME:     iterator_types = ["parallel", "parallel", "reduction", "parallel", "reduction"]} 
+  // CHECK-SAME:  ins(%{{.*}} : tensor<?x?x?xf32>, tensor<?x?x?x?xf32>)
+  // CHECK-SAME: outs(%{{.*}} : tensor<?x?x?xf32>)
+  %0 = linalg.matmul  ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%C: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+
+  //      CHECK: tensor.unpack %{{.*}} inner_dims_pos = [1] inner_tiles = [%[[TS]]] into %[[C]]
+  // CHECK-SAME:   : tensor<?x?x?xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    %sz = transform.structured.match ops{["some_tile_size"]} in %arg1
+    %1 = transform.structured.pack %0 packed_sizes = [0, %sz, %sz] 
+      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+func.func @conv_cant_pack(%i: tensor<14x512x28x28xf32>, %f: tensor<1024x512x1x1xf32>,
+                          %o: tensor<14x1024x28x28xf32>) -> tensor<14x1024x28x28xf32> {
+  %0 = linalg.conv_2d_nchw_fchw ins(%i, %f: tensor<14x512x28x28xf32>, tensor<1024x512x1x1xf32>)
+                                outs(%o: tensor<14x1024x28x28xf32>) -> tensor<14x1024x28x28xf32>
+  return %0: tensor<14x1024x28x28xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match interface{LinalgOp} in %arg1
+  //                                                N  F  H  W  C KH KW
+  // expected-error @below {{data tiling failed}}
+  %1 = transform.structured.pack %0 packed_sizes = [0, 0, 4, 0, 0, 0, 0]
+      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+// -----
+
+func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0 = linalg.matmul  ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%C: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+  %1 = linalg.matmul  ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%C: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    // expected-error @below {{requires target to map to exactly 1 LinalgOp (got 2)}}
+    %1 = transform.structured.pack %0 packed_sizes = [2, 3, 4] 
+      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
+
+
+// -----
+
+func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+  %0 = linalg.matmul  ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%C: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    // expected-error @below {{requires number of packed sizes match the number of loops (2 vs 3)}}
+    %1 = transform.structured.pack %0 packed_sizes = [2, 3] 
+      : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+}
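The expected shapes in the tests above follow one rule that can be checked by hand. Every non-zero entry of `packed_sizes` appends one trailing loop with the same iterator type as the loop it packs. Every operand dimension indexed by such a loop through a plain AffineDimExpr is then tiled: its outer size becomes `ceildiv(dim, tile)` and the tile is appended at the end of the packed type. If a packed loop appears in a compound index such as `d2 + d5`, packing fails. The C++ sketch below reproduces that rule. It was reconstructed from the CHECK lines, not taken from LinalgTransformOps.cpp. The helpers `MapResult`, `PackInfo`, `packIterators`, `packOperand` and `packedShape`, and the `kDynamic` sentinel, are hypothetical. The sketch also ignores `outer_dims_perm` and the padding value.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical sketch reconstructed from the test expectations; it is not the
// upstream implementation of transform.structured.pack.

// Hypothetical sentinel for a dynamic ("?") size or tile.
constexpr int64_t kDynamic = -1;

// One result expression of an operand's indexing map.
struct MapResult {
  std::vector<int> loops;  // loops the expression mentions, e.g. {2, 5} for d2 + d5
  bool isPureDim;          // true iff the expression is exactly one dK
};

// inner_dims_pos / inner_tiles of the tensor.pack created for one operand.
struct PackInfo {
  std::vector<int64_t> innerDimsPos;
  std::vector<int64_t> innerTiles;
};

// Packed iterator types: the original ones, then one entry per packed loop.
std::vector<std::string> packIterators(const std::vector<std::string>& iters,
                                       const std::vector<int64_t>& packedSizes) {
  if (iters.size() != packedSizes.size())
    throw std::invalid_argument("number of packed sizes must match the number of loops");
  std::vector<std::string> result = iters;
  for (std::size_t loop = 0; loop < iters.size(); ++loop)
    if (packedSizes[loop] != 0)  // 0 means "do not pack this loop"
      result.push_back(iters[loop]);
  return result;
}

// Visits packed loops in loop order. Each operand dimension indexed by a packed
// loop is tiled by that loop's size; a packed loop used inside a compound
// expression (e.g. d2 + d5) cannot be packed.
PackInfo packOperand(const std::vector<MapResult>& map,
                     const std::vector<int64_t>& packedSizes) {
  PackInfo info;
  for (std::size_t loop = 0; loop < packedSizes.size(); ++loop) {
    if (packedSizes[loop] == 0) continue;
    for (std::size_t pos = 0; pos < map.size(); ++pos) {
      const MapResult& r = map[pos];
      if (std::find(r.loops.begin(), r.loops.end(), static_cast<int>(loop)) == r.loops.end())
        continue;
      if (!r.isPureDim) throw std::runtime_error("data tiling failed");
      info.innerDimsPos.push_back(static_cast<int64_t>(pos));
      info.innerTiles.push_back(packedSizes[loop]);
    }
  }
  return info;
}

// Result shape of tensor.pack without outer_dims_perm: each tiled outer dim
// becomes ceildiv(dim, tile), then the inner tiles are appended in order.
std::vector<int64_t> packedShape(std::vector<int64_t> shape, const PackInfo& info) {
  for (std::size_t i = 0; i < info.innerDimsPos.size(); ++i) {
    assert(info.innerDimsPos[i] < static_cast<int64_t>(shape.size()));
    int64_t dim = shape[info.innerDimsPos[i]];
    int64_t tile = info.innerTiles[i];
    shape[info.innerDimsPos[i]] =
        (dim == kDynamic || tile == kDynamic) ? kDynamic : (dim + tile - 1) / tile;
  }
  shape.insert(shape.end(), info.innerTiles.begin(), info.innerTiles.end());
  return shape;
}
```

For example, take the `conv_2d_nchw_fchw` filter map `(d1, d4, d5, d6)` with `packed_sizes = [0, 4, 0, 0, 8, 0, 0]`. `packOperand` yields `inner_dims_pos = [0, 1]` and `inner_tiles = [4, 8]`, and `packedShape({1024, 512, 1, 1}, ...)` gives 256x64x1x1x4x8. This matches the `PACK_FILTER` CHECK lines. Because packed loops are visited in loop order, the matmul `B` operand `(d2, d1)` gets `inner_dims_pos = [1, 0]`. Packing H in the convolution throws instead, which mirrors the `conv_cant_pack` case.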


        

