[Mlir-commits] [mlir] 02371c5 - [mlir][Linalg] NFC - Expose packing implementation as a standalone functional-style API call

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jan 24 02:30:46 PST 2023


Author: Nicolas Vasilache
Date: 2023-01-24T02:27:40-08:00
New Revision: 02371c5d666392ae526993dd5331c5ea0caa2840

URL: https://github.com/llvm/llvm-project/commit/02371c5d666392ae526993dd5331c5ea0caa2840
DIFF: https://github.com/llvm/llvm-project/commit/02371c5d666392ae526993dd5331c5ea0caa2840.diff

LOG: [mlir][Linalg] NFC - Expose packing implementation as a standalone functional-style API call
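
The new entry point can be called directly from C++ code instead of going through the transform op. A minimal usage sketch (the surrounding setup — an already-matched `linalgOp` with three loops such as a matmul, the concrete tile sizes, and the enclosing function — is an illustrative assumption, not part of this commit):

  // Assumption: `linalgOp` is a 3-loop linalg op (e.g. a matmul) matched by the caller.
  IRRewriter rewriter(linalgOp->getContext());
  rewriter.setInsertionPoint(linalgOp);
  // One entry per iterator of `linalgOp`; a 0 size leaves that loop unpacked.
  SmallVector<OpFoldResult> packedSizes = {
      rewriter.getIndexAttr(8),    // pack loop 0 by 8
      rewriter.getIndexAttr(16),   // pack loop 1 by 16
      rewriter.getIndexAttr(0)};   // leave loop 2 unpacked
  FailureOr<linalg::LinalgOp> packed =
      linalg::pack(rewriter, linalgOp, packedSizes);
  if (failed(packed))
    return failure();   // sketch: propagate failure in the caller's own way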

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7e8507d3c9df2..07f56770b63c8 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -219,7 +219,7 @@ struct TiledLinalgOp {
 FailureOr<TiledLinalgOp> tileLinalgOp(RewriterBase &b, LinalgOp op,
                                       const LinalgTilingOptions &options);
 
-/// Try to peel anad canonicalize loop `op` and return the new result.
+/// Try to peel and canonicalize loop `op` and return the new result.
 // TODO: Add support for scf.parallel and affine.for loops.
 SmallVector<Value> peelLoop(RewriterBase &rewriter, Operation *op);
 /// Peel and canonicalize 'loops'.
@@ -1141,6 +1141,12 @@ FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
     GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
     RewriterBase &rewriter);
 
+/// Implement packing of a single LinalgOp by `packedSizes`. There must be
+/// one `packedSizes` entry per `linalgOp` iterator.
+/// Return the packed Linalg op on success, failure otherwise.
+FailureOr<linalg::LinalgOp> pack(RewriterBase &rewriter,
+                                 linalg::LinalgOp linalgOp,
+                                 ArrayRef<OpFoldResult> packedSizes);
 } // namespace linalg
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 3df23733452b2..8d06ef6cca0c5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -898,280 +898,6 @@ SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
   return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
 }
 
-#ifndef NDEBUG
-/// Return true if `map` has 0 or 1 result function of AffineDimExpr(dim).
-static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
-  bool found = false;
-  for (AffineExpr e : map.getResults()) {
-    if (!e.isFunctionOfDim(dim))
-      continue;
-    if (found)
-      return false;
-    found = true;
-  }
-  return true;
-}
-#endif // NDEBUG
-
-/// Return the index of the first result of `map` that is a function of
-/// AffineDimExpr(dim), std::nullopt otherwise.
-static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
-                                                            int64_t dim) {
-  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
-    AffineExpr expr = map.getResult(i);
-    if (!expr.isFunctionOfDim(dim))
-      continue;
-    return i;
-  }
-  return std::nullopt;
-}
-
-/// Perform one step of packing of a LinalgOp's metadata along `dim` into the
-/// `newDim` at `iteratorTypes.size()` by:
-///   1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
-///   2. Appending a `newDim` to the domain of every indexing map.
-///   3. For each operand (i.e. for each map in `indexingMaps`), perform packing
-///      by potentially adding a `newDim` result to `map`.
-/// The preserved invariant is that `iteratorTypes.size()` is always equal to
-/// `map.getNumDims()` for every map in `indexingMaps`.
-///
-/// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update.
-/// Return a vector that records the optional packing for each operand.
-/// Return failure if the packed indexing cannot be represented with a LinalgOp.
-///
-/// Further details:
-/// ================
-/// The current implementation of packing (i.e. data tiling) consists of
-/// rewriting a linearized strip-mined form into a higher-dimensional access.
-/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
-/// `I` into `4 * i + ii`, where `0 <= ii < 4`.
-/// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
-///
-/// This rewrite into higher dimensional access is not possible for general
-/// AffineExpr in Linalg atm, it is restricted to an AffineDimExpr:
-/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
-/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
-/// The rewrite of the access would be a form not representable in Linalg:
-///   `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
-/// Note however that as `J` and `ii` iterate, the accesses do not have a
-/// particular alignment, so packing does not achieve alignment in this case.
-///
-/// In the future, we may want to consider a mixed-form that allows some
-/// alignment in the presence of multiple accesses:
-///   `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
-/// And would rewrite accesses as:
-///   `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
-static FailureOr<SmallVector<std::optional<int64_t>>>
-packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
-                       SmallVectorImpl<utils::IteratorType> &iteratorTypes,
-                       int64_t dim) {
-  int64_t newDim = iteratorTypes.size();
-  iteratorTypes.push_back(iteratorTypes[dim]);
-
-  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
-      indexingMaps.size(), std::nullopt);
-  SmallVector<AffineMap> newMaps;
-  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
-       ++operandIdx) {
-    AffineMap map = indexingMaps[operandIdx];
-
-    // Add the `newDim` to map whatever the case.
-    assert(map.getNumDims() == newDim && "num dims invariant violation");
-    map = map.shiftDims(1, newDim);
-
-    // Get the at-most-1 index of the result that is a function of `dim`.
-    // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
-    // logically chunks dimension `dim` into `K * dim + newDim`, where the
-    // packing factor `K` is specified separately.
-    assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
-           "num results invariant violation");
-    auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
-    if (!maybeOperandDimensionToPack.has_value()) {
-      newMaps.push_back(map);
-      continue;
-    }
-
-    // We can only pack AffineDimExpr atm.
-    if (!map.getResult(maybeOperandDimensionToPack.value())
-             .isa<AffineDimExpr>())
-      return failure();
-
-    // Add `newDim` to the results of the map.
-    map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
-                           map.getNumResults());
-    newMaps.push_back(map);
-
-    // Record that `operandIdx` is packed.
-    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
-  }
-  indexingMaps = newMaps;
-
-  return packedDimPerIndexingMap;
-}
-
-namespace {
-
-/// Helper struct to encode packing along one dimension of a LinalgOp.
-struct PackedOperandsDim {
-  OpFoldResult packedSize;
-  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
-};
-
-/// Helper struct to encode packing along all dimensions of a LinalgOp.
-struct PackedOperandsDimList {
-  void push_back(PackedOperandsDim &&packedOperandsDims) {
-    spec.emplace_back(packedOperandsDims);
-  }
-  /// Return all the dims that have been packed for operand @ `operandPos`.
-  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
-  /// Return all the pack sizes by which an operand @ `operandPos` is packed.
-  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);
-
-private:
-  SmallVector<PackedOperandsDim> spec;
-};
-
-} // namespace
-
-SmallVector<int64_t>
-PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
-  SmallVector<int64_t> res;
-  for (int64_t i = 0, e = spec.size(); i < e; ++i) {
-    if (!spec[i].packedDimForEachOperand[operandPos].has_value())
-      continue;
-    res.push_back(spec[i].packedDimForEachOperand[operandPos].value());
-  }
-  return res;
-}
-
-SmallVector<OpFoldResult>
-PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
-  SmallVector<OpFoldResult> res;
-  for (int64_t i = 0, e = spec.size(); i < e; ++i) {
-    if (!spec[i].packedDimForEachOperand[operandPos].has_value())
-      continue;
-    res.push_back(spec[i].packedSize);
-  }
-  return res;
-}
-
-/// Implement packing of a single LinalgOp by performing packing by
-/// `packedSizeHandles`. There must be one packedSizeHandles entry per
-/// `linalgOp` iterator. Return the packed Linalg op on success, failure
-/// otherwise.
-static FailureOr<linalg::LinalgOp>
-packOneLinalgOp(RewriterBase &rewriter, transform::TransformState &state,
-                TransformOpInterface transformOp, linalg::LinalgOp linalgOp,
-                ArrayRef<OpFoldResult> packedSizeHandles) {
-  assert(packedSizeHandles.size() == linalgOp.getNumLoops() &&
-         "incorrect number of pack sizes");
-
-  Location loc = linalgOp->getLoc();
-  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
-  SmallVector<utils::IteratorType> iteratorTypes =
-      linalgOp.getIteratorTypesArray();
-  LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n";
-             llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
-             llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
-             DBGSNL(););
-
-  // Unpack handles to constants or actual SSA index values.
-  SmallVector<OpFoldResult> packedSizes;
-  DiagnosedSilenceableFailure status = unpackSingleIndexResultPDLOperations(
-      state, transformOp, packedSizes, packedSizeHandles);
-
-  // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
-  PackedOperandsDimList listOfPackedOperandsDim;
-  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
-    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
-    // Skip tile sizes explicitly set to 0.
-    if (maybeConstant.has_value() && maybeConstant.value() == 0)
-      continue;
-
-    PackedOperandsDim packedOperandsDims;
-    packedOperandsDims.packedSize = packedSizes[i];
-    FailureOr<SmallVector<std::optional<int64_t>>>
-        maybePackedDimForEachOperand =
-            packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
-    if (failed(maybePackedDimForEachOperand))
-      return failure();
-    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
-    listOfPackedOperandsDim.push_back(std::move(packedOperandsDims));
-
-    LLVM_DEBUG(
-        DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
-               << "\n";
-        llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
-        llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();
-        llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
-                              DBGS() << "packedDimForEachOperand: ");
-        DBGSNL(););
-  }
-
-  // Step 2. Propagate packing to all LinalgOp operands.
-  SmallVector<Value> inputsAndInits, results;
-  for (auto operandsList :
-       {linalgOp.getDpsInputOperands(), linalgOp.getDpsInitOperands()}) {
-    for (OpOperand *opOperandPtr : operandsList) {
-      int64_t pos = opOperandPtr->getOperandNumber();
-      Value operand = opOperandPtr->get();
-      SmallVector<int64_t> innerPos =
-          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
-      SmallVector<OpFoldResult> innerPackSizes =
-          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
-      LLVM_DEBUG(
-          DBGS() << "operand: " << operand << "\n";
-          llvm::interleaveComma(innerPos, DBGS() << "innerPos: "); DBGSNL();
-          llvm::interleaveComma(innerPackSizes, DBGS() << "innerPackSizes: ");
-          DBGSNL(););
-      if (innerPackSizes.empty()) {
-        inputsAndInits.push_back(operand);
-        continue;
-      }
-      Value dest = tensor::PackOp::createDestinationTensor(
-          rewriter, loc, operand, innerPackSizes, innerPos,
-          /*outerDimsPerm=*/{});
-      // TODO: value of the padding attribute should be determined by consumers.
-      Attribute zeroAttr =
-          rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
-      Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
-      inputsAndInits.push_back(rewriter.create<tensor::PackOp>(
-          loc, operand, dest, innerPos, innerPackSizes, zero));
-    }
-  }
-
-  // Step 3. Build the packed op, use the type of `inits` as result types.
-  ValueRange inputs =
-      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
-  ValueRange inits =
-      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
-  auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
-      linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
-      iteratorTypes);
-  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
-
-  // Step 4. Propagate packing to all the op results.
-  for (OpResult result : packedLinalgOp->getResults()) {
-    int64_t resultNum = result.getResultNumber();
-    tensor::PackOp maybePackedInit =
-        inits[resultNum].getDefiningOp<tensor::PackOp>();
-    if (!maybePackedInit) {
-      results.push_back(result);
-      continue;
-    }
-    // Build the symmetrical UnPackOp to the existing PackOp.
-    results.push_back(rewriter.create<tensor::UnPackOp>(
-        packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
-        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
-  }
-
-  // Step 5. Replace `linalgOp`.
-  rewriter.replaceOp(linalgOp, results);
-
-  // Return packedLinalgOp.
-  return cast<linalg::LinalgOp>(packedLinalgOp.getOperation());
-}
-
 DiagnosedSilenceableFailure
 transform::PackOp::apply(transform::TransformResults &transformResults,
                          transform::TransformState &state) {
@@ -1196,10 +922,14 @@ transform::PackOp::apply(transform::TransformResults &transformResults,
            << ")";
   }
 
+  // Unpack handles to constants or actual SSA index values.
+  SmallVector<OpFoldResult> packedSizes;
+  DiagnosedSilenceableFailure status = unpackSingleIndexResultPDLOperations(
+      state, *this, packedSizes, getMixedPackedSizes());
+
   IRRewriter rewriter(linalgOp->getContext());
   rewriter.setInsertionPoint(linalgOp);
-  FailureOr<LinalgOp> maybeResult =
-      packOneLinalgOp(rewriter, state, *this, linalgOp, getMixedPackedSizes());
+  FailureOr<LinalgOp> maybeResult = pack(rewriter, linalgOp, packedSizes);
   if (failed(maybeResult))
     return emitDefiniteFailure("data tiling failed");
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 9f374991aed3c..8e8657d64e851 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -45,6 +45,7 @@ using namespace mlir;
 using namespace mlir::linalg;
 
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
 
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
@@ -694,39 +695,39 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
 
   // Get domain indices based on conv2D layout.
   auto [khIndex, kwIndex, ohIndex, owIndex] =
-      TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t,
-                                         int64_t>>(convOp)
-      .Case([&](linalg::Conv2DNhwcHwcfOp op) {
-        return std::make_tuple(0, 1, 1, 2);
-      })
-      .Case([&](linalg::Conv2DNchwFchwOp op) {
-        return std::make_tuple(2, 3, 2, 3);
-      })
-      .Case([&](linalg::PoolingNhwcSumOp op) {
-        return std::make_tuple(0, 1, 1, 2);
-      })
-      .Case([&](linalg::PoolingNchwSumOp op) {
-        return std::make_tuple(0, 1, 2, 3);
-      })
-      .Case([&](linalg::PoolingNhwcMaxOp op) {
-        return std::make_tuple(0, 1, 1, 2);
-      })
-      .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
-        return std::make_tuple(0, 1, 1, 2);
-      })
-      .Case([&](linalg::PoolingNhwcMinOp op) {
-        return std::make_tuple(0, 1, 1, 2);
-      })
-      .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
-        return std::make_tuple(0, 1, 1, 2);
-      })
-      .Case([&](linalg::PoolingNchwMaxOp op) {
-        return std::make_tuple(0, 1, 2, 3);
-      })
-      .Default([&](Operation *op) {
-        llvm_unreachable("unexpected conv2d/pool2d operation.");
-        return std::make_tuple(0, 0, 0, 0);
-      });
+      TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
+          convOp)
+          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
+            return std::make_tuple(0, 1, 1, 2);
+          })
+          .Case([&](linalg::Conv2DNchwFchwOp op) {
+            return std::make_tuple(2, 3, 2, 3);
+          })
+          .Case([&](linalg::PoolingNhwcSumOp op) {
+            return std::make_tuple(0, 1, 1, 2);
+          })
+          .Case([&](linalg::PoolingNchwSumOp op) {
+            return std::make_tuple(0, 1, 2, 3);
+          })
+          .Case([&](linalg::PoolingNhwcMaxOp op) {
+            return std::make_tuple(0, 1, 1, 2);
+          })
+          .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
+            return std::make_tuple(0, 1, 1, 2);
+          })
+          .Case([&](linalg::PoolingNhwcMinOp op) {
+            return std::make_tuple(0, 1, 1, 2);
+          })
+          .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
+            return std::make_tuple(0, 1, 1, 2);
+          })
+          .Case([&](linalg::PoolingNchwMaxOp op) {
+            return std::make_tuple(0, 1, 2, 3);
+          })
+          .Default([&](Operation *op) {
+            llvm_unreachable("unexpected conv2d/pool2d operation.");
+            return std::make_tuple(0, 0, 0, 0);
+          });
 
   // Only handle the case where at least one of the window dimensions is
   // of size 1. Other cases can rely on tiling to reduce to such cases.
@@ -887,3 +888,276 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
       DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
       patterns.getContext(), benefit);
 }
+
+//===----------------------------------------------------------------------===//
+// pack transformation.
+//===----------------------------------------------------------------------===//
+
+#ifndef NDEBUG
+/// Return true if `map` has 0 or 1 result function of AffineDimExpr(dim).
+static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
+  bool found = false;
+  for (AffineExpr e : map.getResults()) {
+    if (!e.isFunctionOfDim(dim))
+      continue;
+    if (found)
+      return false;
+    found = true;
+  }
+  return true;
+}
+#endif // NDEBUG
+
+/// Return the index of the first result of `map` that is a function of
+/// AffineDimExpr(dim), std::nullopt otherwise.
+static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
+                                                            int64_t dim) {
+  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+    AffineExpr expr = map.getResult(i);
+    if (!expr.isFunctionOfDim(dim))
+      continue;
+    return i;
+  }
+  return std::nullopt;
+}
+
+/// Perform one step of packing of a LinalgOp's metadata along `dim` into the
+/// `newDim` at `iteratorTypes.size()` by:
+///   1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
+///   2. Appending a `newDim` to the domain of every indexing map.
+///   3. For each operand (i.e. for each map in `indexingMaps`), perform packing
+///      by potentially adding a `newDim` result to `map`.
+/// The preserved invariant is that `iteratorTypes.size()` is always equal to
+/// `map.getNumDims()` for every map in `indexingMaps`.
+///
+/// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update.
+/// Return a vector that records the optional packing for each operand.
+/// Return failure if the packed indexing cannot be represented with a LinalgOp.
+///
+/// Further details:
+/// ================
+/// The current implementation of packing (i.e. data tiling) consists of
+/// rewriting a linearized strip-mined form into a higher-dimensional access.
+/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
+/// `I` into `4 * i + ii`, where `0 <= ii < 4`.
+/// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
+///
+/// This rewrite into higher dimensional access is not possible for general
+/// AffineExpr in Linalg atm, it is restricted to an AffineDimExpr:
+/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
+/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
+/// The rewrite of the access would be a form not representable in Linalg:
+///   `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
+/// Note however that as `J` and `ii` iterate, the accesses do not have a
+/// particular alignment, so packing does not achieve alignment in this case.
+///
+/// In the future, we may want to consider a mixed-form that allows some
+/// alignment in the presence of multiple accesses:
+///   `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
+/// And would rewrite accesses as:
+///   `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
+static FailureOr<SmallVector<std::optional<int64_t>>>
+packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
+                       SmallVectorImpl<utils::IteratorType> &iteratorTypes,
+                       int64_t dim) {
+  int64_t newDim = iteratorTypes.size();
+  iteratorTypes.push_back(iteratorTypes[dim]);
+
+  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
+      indexingMaps.size(), std::nullopt);
+  SmallVector<AffineMap> newMaps;
+  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
+       ++operandIdx) {
+    AffineMap map = indexingMaps[operandIdx];
+
+    // Add the `newDim` to map whatever the case.
+    assert(map.getNumDims() == newDim && "num dims invariant violation");
+    map = map.shiftDims(1, newDim);
+
+    // Get the at-most-1 index of the result that is a function of `dim`.
+    // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
+    // logically chunks dimension `dim` into `K * dim + newDim`, where the
+    // packing factor `K` is specified separately.
+    assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
+           "num results invariant violation");
+    auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
+    if (!maybeOperandDimensionToPack.has_value()) {
+      newMaps.push_back(map);
+      continue;
+    }
+
+    // We can only pack AffineDimExpr atm.
+    if (!map.getResult(maybeOperandDimensionToPack.value())
+             .isa<AffineDimExpr>())
+      return failure();
+
+    // Add `newDim` to the results of the map.
+    map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
+                           map.getNumResults());
+    newMaps.push_back(map);
+
+    // Record that `operandIdx` is packed.
+    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
+  }
+  indexingMaps = newMaps;
+
+  return packedDimPerIndexingMap;
+}
+
+namespace {
+
+/// Helper struct to encode packing along one dimension of a LinalgOp.
+struct PackedOperandsDim {
+  OpFoldResult packedSize;
+  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
+};
+
+/// Helper struct to encode packing along all dimensions of a LinalgOp.
+struct PackedOperandsDimList {
+  void push_back(PackedOperandsDim &&packedOperandsDims) {
+    spec.emplace_back(packedOperandsDims);
+  }
+  /// Return all the dims that have been packed for operand @ `operandPos`.
+  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
+  /// Return all the pack sizes by which an operand @ `operandPos` is packed.
+  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);
+
+private:
+  SmallVector<PackedOperandsDim> spec;
+};
+
+} // namespace
+
+SmallVector<int64_t>
+PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
+  SmallVector<int64_t> res;
+  for (int64_t i = 0, e = spec.size(); i < e; ++i) {
+    if (!spec[i].packedDimForEachOperand[operandPos].has_value())
+      continue;
+    res.push_back(spec[i].packedDimForEachOperand[operandPos].value());
+  }
+  return res;
+}
+
+SmallVector<OpFoldResult>
+PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
+  SmallVector<OpFoldResult> res;
+  for (int64_t i = 0, e = spec.size(); i < e; ++i) {
+    if (!spec[i].packedDimForEachOperand[operandPos].has_value())
+      continue;
+    res.push_back(spec[i].packedSize);
+  }
+  return res;
+}
+
+/// Implement packing of a single LinalgOp by `packedSizes`. There must be
+/// one `packedSizes` entry per `linalgOp` iterator.
+/// Return the packed Linalg op on success, failure otherwise.
+FailureOr<linalg::LinalgOp> linalg::pack(RewriterBase &rewriter,
+                                         linalg::LinalgOp linalgOp,
+                                         ArrayRef<OpFoldResult> packedSizes) {
+  if (packedSizes.size() != linalgOp.getNumLoops()) {
+    return rewriter.notifyMatchFailure(linalgOp,
+                                       "incorrect number of pack sizes");
+  }
+
+  Location loc = linalgOp->getLoc();
+  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+  SmallVector<utils::IteratorType> iteratorTypes =
+      linalgOp.getIteratorTypesArray();
+  LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n";
+             llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
+             llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
+             DBGSNL(););
+
+  // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
+  PackedOperandsDimList listOfPackedOperandsDim;
+  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
+    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
+    // Skip tile sizes explicitly set to 0.
+    if (maybeConstant.has_value() && maybeConstant.value() == 0)
+      continue;
+
+    PackedOperandsDim packedOperandsDims;
+    packedOperandsDims.packedSize = packedSizes[i];
+    FailureOr<SmallVector<std::optional<int64_t>>>
+        maybePackedDimForEachOperand =
+            packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
+    if (failed(maybePackedDimForEachOperand))
+      return failure();
+    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
+    listOfPackedOperandsDim.push_back(std::move(packedOperandsDims));
+
+    LLVM_DEBUG(
+        DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
+               << "\n";
+        llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
+        llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();
+        llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
+                              DBGS() << "packedDimForEachOperand: ");
+        DBGSNL(););
+  }
+
+  // Step 2. Propagate packing to all LinalgOp operands.
+  SmallVector<Value> inputsAndInits, results;
+  for (auto operandsList :
+       {linalgOp.getDpsInputOperands(), linalgOp.getDpsInitOperands()}) {
+    for (OpOperand *opOperandPtr : operandsList) {
+      int64_t pos = opOperandPtr->getOperandNumber();
+      Value operand = opOperandPtr->get();
+      SmallVector<int64_t> innerPos =
+          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
+      SmallVector<OpFoldResult> innerPackSizes =
+          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
+      LLVM_DEBUG(
+          DBGS() << "operand: " << operand << "\n";
+          llvm::interleaveComma(innerPos, DBGS() << "innerPos: "); DBGSNL();
+          llvm::interleaveComma(innerPackSizes, DBGS() << "innerPackSizes: ");
+          DBGSNL(););
+      if (innerPackSizes.empty()) {
+        inputsAndInits.push_back(operand);
+        continue;
+      }
+      Value dest = tensor::PackOp::createDestinationTensor(
+          rewriter, loc, operand, innerPackSizes, innerPos,
+          /*outerDimsPerm=*/{});
+      // TODO: value of the padding attribute should be determined by consumers.
+      Attribute zeroAttr =
+          rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
+      Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
+      inputsAndInits.push_back(rewriter.create<tensor::PackOp>(
+          loc, operand, dest, innerPos, innerPackSizes, zero));
+    }
+  }
+
+  // Step 3. Build the packed op, use the type of `inits` as result types.
+  ValueRange inputs =
+      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
+  ValueRange inits =
+      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
+  auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
+      linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
+      iteratorTypes);
+  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
+
+  // Step 4. Propagate packing to all the op results.
+  for (OpResult result : packedLinalgOp->getResults()) {
+    int64_t resultNum = result.getResultNumber();
+    tensor::PackOp maybePackedInit =
+        inits[resultNum].getDefiningOp<tensor::PackOp>();
+    if (!maybePackedInit) {
+      results.push_back(result);
+      continue;
+    }
+    // Build the symmetrical UnPackOp to the existing PackOp.
+    results.push_back(rewriter.create<tensor::UnPackOp>(
+        packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
+        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
+  }
+
+  // Step 5. Replace `linalgOp`.
+  rewriter.replaceOp(linalgOp, results);
+
+  // Return packedLinalgOp.
+  return cast<linalg::LinalgOp>(packedLinalgOp.getOperation());
+}
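
For readers following the `packLinalgMetadataOnce` documentation above, the per-map rewrite reduces to the two AffineMap calls the implementation already uses: grow the domain by one trailing dimension, then append that dimension as a new result when the packed iterator appears as a plain AffineDimExpr. A standalone sketch (the matmul-style map and the surrounding context are assumptions for illustration, not taken from this commit; it presumes `using namespace mlir;` and the usual mlir/IR/AffineMap.h and mlir/IR/Builders.h includes):

  // Assumed example: pack iterator d0 of a matmul LHS map A: (d0, d1, d2) -> (d0, d2).
  MLIRContext ctx;
  Builder b(&ctx);
  AffineMap lhsMap = AffineMap::get(
      /*dimCount=*/3, /*symbolCount=*/0,
      {b.getAffineDimExpr(0), b.getAffineDimExpr(2)}, &ctx);
  // Grow the domain by one trailing dim: (d0, d1, d2, d3) -> (d0, d2).
  AffineMap grown = lhsMap.shiftDims(1, /*offset=*/3);
  // d0 appears as a plain AffineDimExpr, so the packed dim is appended as a
  // new result: (d0, d1, d2, d3) -> (d0, d2, d3).
  AffineMap packed =
      grown.insertResult(b.getAffineDimExpr(3), grown.getNumResults());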


        

