[Mlir-commits] [mlir] 4d74c84 - [MLIR][Linalg] Expose `packMatmulGreedily` in `Transforms.h` (NFC)

Lorenzo Chelini llvmlistbot at llvm.org
Thu Jul 6 03:00:40 PDT 2023


Author: Lorenzo Chelini
Date: 2023-07-06T11:59:17+02:00
New Revision: 4d74c845a1c2519639cbbcf23fb25f7a488667e4

URL: https://github.com/llvm/llvm-project/commit/4d74c845a1c2519639cbbcf23fb25f7a488667e4
DIFF: https://github.com/llvm/llvm-project/commit/4d74c845a1c2519639cbbcf23fb25f7a488667e4.diff

LOG: [MLIR][Linalg] Expose `packMatmulGreedily` in `Transforms.h` (NFC)

Make the transformation accessible to other drivers (e.g., passes).

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 8be49d05c2f0e0..fe48ba5d035995 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -677,8 +677,8 @@ def PackOp : Op<Transform_Dialect, "structured.pack", [
     In particular, note that the second operand `B` has shape `KxNxnxk` (and not
     `KxNxkxn` as one could expect by looking **only** at the operand).
 
-    Other layouts can be obtained unsurprisingly from this canonical 
-    transformation by composing the resulting operation with a (future) 
+    Other layouts can be obtained unsurprisingly from this canonical
+    transformation by composing the resulting operation with a
     `transform.structured.pack_transpose` op.
     This composition allows separating concerns and composes better compared
     to adding additional permutation attributes to this transform op.

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 968488cead1720..d02f798c72030a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1051,6 +1051,20 @@ packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
               linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
               ArrayRef<int64_t> outerPerm, ArrayRef<int64_t> innerPerm);
 
+/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
+/// and n are proper parallel dimensions and k is a proper reduction
+/// dimension. Packing occurs by rewriting the op as a linalg.generic and
+/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
+/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
+/// to reorder {m, n, k} into one of the 6 possible forms. The outer
+/// dimensions of the operands are not permuted at this time, this is left for
+/// future work.
+FailureOr<PackResult>
+packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
+                   ArrayRef<OpFoldResult> mnkPackedSizes,
+                   ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
+                   ArrayRef<int64_t> mnkOrder);
+
 /// Rewrite tensor.from_elements to linalg.generic.
 FailureOr<Operation *>
 rewriteInDestinationPassingStyle(RewriterBase &rewriter,

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index a9e675b902e765..440272ae8a6c47 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1244,145 +1244,6 @@ LogicalResult transform::PackGreedilyOp::verify() {
   return success();
 }
 
-/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
-/// and n are proper parallel dimensions and k is a proper reduction
-/// dimension. Packing occurs by rewriting the op as a linalg.generic and
-/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
-/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
-/// to reorder {m, n, k} into one of the 8 possible forms. The outer
-/// dimensions of the operands are not permuted at this time, this is left for
-/// future work.
-static FailureOr<PackResult>
-packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
-                   ArrayRef<OpFoldResult> mnkPackedSizes,
-                   ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
-                   ArrayRef<int64_t> mnkOrder) {
-  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
-  assert((mnkPaddedSizesNextMultipleOf.empty() ||
-          mnkPaddedSizesNextMultipleOf.size() == 3) &&
-         "num of packing sizes next multiple should be empty or of size 3");
-  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
-  assert(isPermutationVector(mnkOrder) && "expected a permutation");
-
-  int64_t numLoops = linalgOp.getNumLoops();
-  if (numLoops <= 2) {
-    LDBG("need 3+ loops to find a matmul to pack, got "
-         << numLoops << "\nin: " << linalgOp << "\n");
-    return rewriter.notifyMatchFailure(
-        linalgOp, "need 3+ loops to find a matmul to pack");
-  }
-
-  // Locally adjust the desired iterator position of mnk and packing sizes.
-  int64_t numPackedDims = mnkPackedSizes.size();
-  SmallVector<int64_t> mmnnkkPos(numPackedDims);
-  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
-    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
-  SmallVector<OpFoldResult> packedSizes(numPackedDims);
-  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
-    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
-  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
-  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
-    paddedSizesNextMultipleOf[mnkOrder[i]] =
-        mnkPaddedSizesNextMultipleOf.empty() ? 0
-                                             : mnkPaddedSizesNextMultipleOf[i];
-  }
-
-  // 1. Infer dims that are important for matmul.
-  FailureOr<ContractionDimensions> maybeDimensions =
-      inferContractionDims(linalgOp);
-  if (failed(maybeDimensions)) {
-    LDBG("couldn't infer matmul iterators in: " << linalgOp << "\n");
-    return rewriter.notifyMatchFailure(linalgOp,
-                                       "couldn't infer matmul iterators");
-  }
-
-  // 2. Normalize linalgOp to an kmn-matmul-like with [red, par, par] most
-  // minor iterators. In cases with multiple options for m, n, k bias towards
-  // the most minor embedding.
-  // If we wanted a different normalization order, this is where it would have
-  // to plug a heuristic.
-  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
-          kPos = maybeDimensions->k.back();
-  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
-             DBGS() << "Start packing generic op greedily with (m@" << mPos
-                    << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
-                    << "\n";);
-
-  // 2.a. Rewrite as a generic.
-  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
-  if (!genericOp) {
-    FailureOr<GenericOp> generalizeResult =
-        generalizeNamedOp(rewriter, linalgOp);
-    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
-    genericOp = *generalizeResult;
-  }
-
-  // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
-  // iterators. Note that this only normalized the iteration order and does
-  // not change the indexings of any operand.
-  SmallVector<int64_t> permutation =
-      computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
-  LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
-  // Sign .. unsigned pollution.
-  SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
-  FailureOr<GenericOp> interchangeResult =
-      interchangeGenericOp(rewriter, genericOp, unsignedPerm);
-  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
-  genericOp = *interchangeResult;
-  LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);
-
-  // At this point, the op iterators are normalized to {leading, k, m, n}.
-  // The layouts induced by packing will always be:
-  //   - LHS{leading_lhs, kk, mm}
-  //   - RHS{leading_rhs, kk, nn}
-  //   - RES{leading_res, mm, nn}
-  // If we wanted to change the packed order, we would reorder (k, m, n) to
-  // something else above.
-  //
-  // Additional permutations of the outer dims of the operands (i.e.
-  // leading_lhs, leading_rhs and leading_res) could follow by computing the
-  // desired outerPerm for each operand.
-  // This is left for future work.
-
-  // TODO: this creates too much IR, go use reifyResultShapes.
-  SmallVector<Range, 4> loopRanges =
-      cast<LinalgOp>(genericOp.getOperation())
-          .createLoopRanges(rewriter, genericOp.getLoc());
-
-  // Add leading zeros to match numLoops, we only pack the last 3 dimensions
-  // post interchange.
-  LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
-                                   DBGS() << "paddedSizesNextMultipleOf: ");
-             DBGSNL(););
-  LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
-                                   [](Range r) { llvm::dbgs() << r.size; });
-             DBGSNL(););
-  SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
-                                                rewriter.getIndexAttr(0));
-  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
-    if (paddedSizesNextMultipleOf[i] == 0) {
-      adjustedPackedSizes.push_back(packedSizes[i]);
-      continue;
-    }
-    AffineExpr d0, s0;
-    bindDims(rewriter.getContext(), d0);
-    bindSymbols(rewriter.getContext(), s0);
-    adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
-        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
-        {loopRanges[adjustedPackedSizes.size()].size,
-         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
-  }
-  LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
-                                   DBGS() << "adjustedPackedSizes: ");
-             DBGSNL(););
-
-  // TODO: If we wanted to give the genericOp a name after packing, after
-  // calling `pack` would be a good time. One would still need to check that
-  // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
-  // also allow degenerate matmul cases (i.e. matvec, dot).
-  return linalg::pack(rewriter, genericOp, adjustedPackedSizes);
-}
-
 DiagnosedSilenceableFailure
 PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
                       transform::TransformResults &transformResults,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 2a41bcc14e08fe..ce5dd46ad8f44d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -752,6 +752,150 @@ linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
                              transposedUnPackOp};
 }
 
+//===----------------------------------------------------------------------===//
+// packMatmulGreedily transformation.
+//===----------------------------------------------------------------------===//
+
+/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
+/// and n are proper parallel dimensions and k is a proper reduction
+/// dimension. Packing occurs by rewriting the op as a linalg.generic and
+/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
+/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
+/// to reorder {m, n, k} into one of the 6 possible forms. The outer
+/// dimensions of the operands are not permuted at this time, this is left for
+/// future work.
+FailureOr<PackResult>
+linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
+                           ArrayRef<OpFoldResult> mnkPackedSizes,
+                           ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
+                           ArrayRef<int64_t> mnkOrder) {
+  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
+  assert((mnkPaddedSizesNextMultipleOf.empty() ||
+          mnkPaddedSizesNextMultipleOf.size() == 3) &&
+         "num of packing sizes next multiple should be empty or of size 3");
+  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
+  assert(isPermutationVector(mnkOrder) && "expected a permutation");
+
+  int64_t numLoops = linalgOp.getNumLoops();
+  if (numLoops <= 2) {
+    LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
+                      << numLoops << "\nin: " << linalgOp << "\n");
+    return rewriter.notifyMatchFailure(
+        linalgOp, "need 3+ loops to find a matmul to pack");
+  }
+
+  // Locally adjust the desired iterator position of mnk and packing sizes.
+  int64_t numPackedDims = mnkPackedSizes.size();
+  SmallVector<int64_t> mmnnkkPos(numPackedDims);
+  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
+    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
+  SmallVector<OpFoldResult> packedSizes(numPackedDims);
+  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
+    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
+  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
+  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
+    paddedSizesNextMultipleOf[mnkOrder[i]] =
+        mnkPaddedSizesNextMultipleOf.empty() ? 0
+                                             : mnkPaddedSizesNextMultipleOf[i];
+  }
+
+  // 1. Infer dims that are important for matmul.
+  FailureOr<ContractionDimensions> maybeDimensions =
+      inferContractionDims(linalgOp);
+  if (failed(maybeDimensions)) {
+    LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
+                      << "\n");
+    return rewriter.notifyMatchFailure(linalgOp,
+                                       "couldn't infer matmul iterators");
+  }
+
+  // 2. Normalize linalgOp to a matmul-like op whose m, n, k iterators are the
+  // three most-minor ones, ordered per `mnkOrder`. In cases with multiple
+  // options for m, n, k, bias towards the most minor embedding.
+  // If we wanted a different normalization order, this is where it would have
+  // to plug a heuristic.
+  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
+          kPos = maybeDimensions->k.back();
+  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
+             DBGS() << "Start packing generic op greedily with (m@" << mPos
+                    << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
+                    << "\n";);
+
+  // 2.a. Rewrite as a generic.
+  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
+  if (!genericOp) {
+    FailureOr<GenericOp> generalizeResult =
+        generalizeNamedOp(rewriter, linalgOp);
+    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
+    genericOp = *generalizeResult;
+  }
+
+  // 2.b. Interchange to move the m, n, k dimensions to the most-minor
+  // iterator positions selected by `mnkOrder`. Note that this only normalizes
+  // the iteration order and does not change the indexings of any operand.
+  SmallVector<int64_t> permutation =
+      computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
+  LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
+  // Sign .. unsigned pollution.
+  SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
+  FailureOr<GenericOp> interchangeResult =
+      interchangeGenericOp(rewriter, genericOp, unsignedPerm);
+  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
+  genericOp = *interchangeResult;
+  LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);
+
+  // At this point, the three most-minor iterators are (m, n, k) permuted
+  // by `mnkOrder`. E.g. {1, 2, 0} yields {leading, k, m, n}, for which the
+  // layouts induced by packing are:
+  //   - LHS{leading_lhs, kk, mm}
+  //   - RHS{leading_rhs, kk, nn}
+  //   - RES{leading_res, mm, nn}
+  // Other packed orders are obtained via a different `mnkOrder`.
+  //
+  // Additional permutations of the outer dims of the operands (i.e.
+  // leading_lhs, leading_rhs and leading_res) could follow by computing the
+  // desired outerPerm for each operand.
+  // This is left for future work.
+
+  // TODO: this creates too much IR, go use reifyResultShapes.
+  SmallVector<Range, 4> loopRanges =
+      cast<LinalgOp>(genericOp.getOperation())
+          .createLoopRanges(rewriter, genericOp.getLoc());
+
+  // Add leading zeros to match numLoops, we only pack the last 3 dimensions
+  // post interchange.
+  LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
+                                   DBGS() << "paddedSizesNextMultipleOf: ");
+             DBGSNL(););
+  LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
+                                   [](Range r) { llvm::dbgs() << r.size; });
+             DBGSNL(););
+  SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
+                                                rewriter.getIndexAttr(0));
+  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
+    if (paddedSizesNextMultipleOf[i] == 0) {
+      adjustedPackedSizes.push_back(packedSizes[i]);
+      continue;
+    }
+    AffineExpr d0, s0;
+    bindDims(rewriter.getContext(), d0);
+    bindSymbols(rewriter.getContext(), s0);
+    adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
+        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
+        {loopRanges[adjustedPackedSizes.size()].size,
+         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
+  }
+  LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
+                                   DBGS() << "adjustedPackedSizes: ");
+             DBGSNL(););
+
+  // TODO: If we wanted to give the genericOp a name after packing, after
+  // calling `pack` would be a good time. One would still need to check that
+  // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
+  // also allow degenerate matmul cases (i.e. matvec, dot).
+  return pack(rewriter, genericOp, adjustedPackedSizes);
+}
+
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list