[Mlir-commits] [mlir] 4ca52c6 - [mlir][Linalg] Add a transform.structured.lower_pack op
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Jan 31 10:06:16 PST 2023
Author: Nicolas Vasilache
Date: 2023-01-31T10:06:08-08:00
New Revision: 4ca52c6e7ed938ac20fa9845af9dbb7d46136226
URL: https://github.com/llvm/llvm-project/commit/4ca52c6e7ed938ac20fa9845af9dbb7d46136226
DIFF: https://github.com/llvm/llvm-project/commit/4ca52c6e7ed938ac20fa9845af9dbb7d46136226.diff
LOG: [mlir][Linalg] Add a transform.structured.lower_pack op
This revision introduces `transform.structured.lower_pack` which allows
rewriting a `tensor.pack` to `tensor.pad` + `tensor.expand_shape` + `linalg.transpose`.
The implementation is currently limited to static pack ops that do not have outer_dims permutations.
Differential Revision: https://reviews.llvm.org/D142881
Added:
mlir/test/Dialect/Linalg/transform-lower-pack.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index ae7c5cb5a91f6..5bd5c034e1c5b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -17,11 +17,16 @@
namespace mlir {
class TilingInterface;
class RewriterBase;
+
namespace linalg {
class GenericOp;
class LinalgOp;
} // namespace linalg
+namespace tensor {
+class PackOp;
+} // namespace tensor
+
namespace transform {
class TransformHandleTypeInterface;
// Types needed for builders.
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 043dc778b00fb..2aee320b81936 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -215,6 +215,43 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
}];
}
+//===----------------------------------------------------------------------===//
+// LowerPackOp
+//===----------------------------------------------------------------------===//
+def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ TransformEachOpTrait,
+ TransformOpInterface]> {
+ let description = [{
+ Rewrite a tensor.pack into tensor.pad + tensor.expand_shape + linalg.transpose.
+
+ #### Return modes
+
+ This operation ignores non-pack ops and drops them in the return.
+ This operation produces a silenceableFailure if the rewrite fails for any
+ reason.
+ If all the operations referred to by the `target` are rewritten, the
+ transform succeeds.
+ Return handles to the newly produced pad, expand_shape and transpose ops.
+ }];
+
+ let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target);
+ let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
+ Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
+ Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);
+ let assemblyFormat = [{
+ $target attr-dict `:` functional-type(operands, results)
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::tensor::PackOp target,
+ ::mlir::transform::ApplyToEachResultList &transformResults,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// MatchOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6cb73bacb0e06..7fe9dfe2fbf5f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1142,12 +1142,17 @@ FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter);
+/// Struct to hold the ops produced by a successful `pack` call.
+struct PackResult {
+ SmallVector<tensor::PackOp> packOps; // tensor.pack ops created by `pack`.
+ linalg::LinalgOp packedLinalgOp; // The packed op built by `pack`.
+ SmallVector<tensor::UnPackOp> unPackOps; // tensor.unpack ops for results.
+};
/// Implement packing of a single LinalgOp by `packedSizes`.
/// There must be one packedSizes entry per `linalgOp` iterator.
/// Return the packed Linalg op on success, failure otherwise.
-FailureOr<linalg::LinalgOp> pack(RewriterBase &rewriter,
- linalg::LinalgOp linalgOp,
- ArrayRef<OpFoldResult> packedSizes);
+FailureOr<PackResult> pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
+ ArrayRef<OpFoldResult> packedSizes);
/// Struct to hold the result of a `packTranspose` call.
struct PackTransposeResult {
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index b055e2ccba9af..cc8bbd570ef66 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -115,7 +115,7 @@ FailureOr<int64_t> getConstantUpperBoundForIndex(Value value);
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
Value source, Value pad, bool nofold);
-/// Returns a GenericOp that tansposes `inputTensor` into `outputTensor` using
+/// Returns a GenericOp that transposes `inputTensor` into `outputTensor` using
/// `transposeVector` to permute the `inputTensor` dimensions.
GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
Value outputTensor,
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 209e70f53123b..9328b1f4b2b83 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1113,7 +1113,13 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
}]>
];
- let extraClassDeclaration = commonExtraClassDeclaration;
+ let extraClassDeclaration = commonExtraClassDeclaration # [{
+ static RankedTensorType
+ inferCollapsedType(RankedTensorType type, ArrayRef<AffineMap> reassociation);
+ static RankedTensorType
+ inferCollapsedType(RankedTensorType type,
+ SmallVector<ReassociationIndices> reassociation);
+ }];
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 2060dd7649781..9952bb1cc5ac8 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -14,10 +14,12 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
@@ -29,12 +31,13 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
-#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
@@ -131,6 +134,81 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations(
return DiagnosedSilenceableFailure::success();
}
+/// Return a permutation vector `res` of size `permSize` in which
+/// `res[desiredPositions[i]] == positions[i]` for every i; the slots left over
+/// are filled, in order, with the increasing indices absent from `positions`.
+/// For example, permSize == 5, positions = {2, 4}, desiredPositions = {1, 0}
+/// would result in a {4, 2, 0, 1, 3} permutation vector.
+static SmallVector<int64_t>
+computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
+ ArrayRef<int64_t> desiredPositions) {
+ SmallVector<int64_t> res(permSize, -1); // -1 marks a slot not assigned yet.
+ DenseSet<int64_t> seen; // Indices from `positions` already placed in `res`.
+ for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
+ res[desiredPos] = pos;
+ seen.insert(pos);
+ }
+ int64_t nextPos = 0; // Smallest candidate index for the next free slot.
+ for (int64_t &entry : res) {
+ if (entry != -1)
+ continue;
+ while (seen.contains(nextPos)) // Skip indices that were explicitly placed.
+ ++nextPos;
+ entry = nextPos;
+ ++nextPos;
+ }
+ return res;
+}
+
+struct PackingMetadata { // Result of computePackingMetadata.
+ SmallVector<int64_t> insertPositions; // One entry per inner_dims_pos entry.
+ SmallVector<ReassociationIndices> reassociations; // Reshape index groups.
+};
+/// Given `innerDimPos`, the indices representing the desired packing insertion
+/// points into a target vector (i.e. pack/unpack.inner_dims_pos), compute the
+/// final positions in the target shape as well as the reshape reassociations.
+// Note: This should not be called with a large positions array (or the
+// implementation needs to be updated to use an N.log N sort instead of
+// repeated N^2 counts).
+static PackingMetadata computePackingMetadata(int64_t packedRank,
+ ArrayRef<int64_t> innerDimPos) {
+ PackingMetadata res;
+ res.insertPositions.reserve(innerDimPos.size());
+ // The pack insert position is the position + the number of entries of
+ // `innerDimPos` that are smaller than it + offset.
+ // The offset controls whether the packing dimension is the first or last.
+ //
+ // Example
+ // =======
+ // Consider packing from a hypothetical ABCD layout to ABCDba whose
+ // pack.inner_dims_pos is [1, 0]. The first step consists in undoing the
+ // permutation and producing AaBbCD. This is achieved purely by computing the
+ // insert positions of `b` and `a` into `ABCD`, starting from [1, 0]. One
+ // possibility is to produce insert positions [2, 0]; this would result in an
+ // aAbBCD layout (i.e. offset 0). The other possibility is to produce insert
+ // positions [3, 1]; this would result in an AaBbCD layout (i.e. offset 1).
+ // The latter is what we expect from packing.
+ int64_t offset = 1;
+ for (int64_t pos : innerDimPos) {
+ int64_t numInsertedBefore = llvm::count_if(
+ innerDimPos, [&pos](int64_t pos2) { return pos > pos2; });
+ res.insertPositions.push_back(pos + numInsertedBefore + offset);
+ }
+
+ DenseSet<int64_t> posSet(res.insertPositions.begin(),
+ res.insertPositions.end());
+ res.reassociations.reserve(packedRank);
+ for (int64_t i = 1; i <= packedRank; ++i) { // Group ends at i or at i - 1.
+ if (!posSet.contains(i)) { // Dim i - 1 is not followed by an inserted dim.
+ res.reassociations.push_back(ReassociationIndices{i - 1});
+ continue;
+ }
+ res.reassociations.push_back(ReassociationIndices{i - 1, i});
+ ++i; // Dim i was consumed by the pair above; skip it.
+ }
+ return res;
+}
+
//===----------------------------------------------------------------------===//
// DecomposeOp
//===----------------------------------------------------------------------===//
@@ -323,7 +401,7 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
Diagnostic &diag,
Operation *producerOp,
Operation *containingOp) {
- LLVM_DEBUG(llvm::dbgs() << "Try to fuse a direct extract use\n");
+ LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer) {
diag.attachNote(producerOp->getLoc())
@@ -354,7 +432,7 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
// Tile the producer.
int64_t resultNumber =
sliceOpToTile.getSource().cast<OpResult>().getResultNumber();
- LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
+ LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
@@ -364,7 +442,7 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
<< "failed to tile producer op: " << *tileableProducer;
return nullptr;
}
- LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n");
+ LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledProducer << "\n");
// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
@@ -388,8 +466,7 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
Operation *containingOp) {
- LLVM_DEBUG(
- llvm::dbgs() << "Try to fuse an extract use through block argument\n");
+ LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer) {
@@ -442,7 +519,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
// Replace the use in the tileableProducer before tiling: clone, replace and
// then tile.
int64_t resultNumber = pUse->get().cast<OpResult>().getResultNumber();
- LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
+ LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
// Gather destination tensors.
SmallVector<Value> destinationTensors;
@@ -471,7 +548,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
<< "failed to tile producer op: " << *tileableProducer;
return nullptr;
}
- LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n");
+ LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledProducer << "\n");
// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
@@ -496,7 +573,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
Operation *producerOp,
Operation *containingOp) {
- LLVM_DEBUG(llvm::dbgs() << "Try to fuse an use by cloning\n");
+ LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n");
// Gather all uses inside the containing op.
SmallVector<OpOperand *> uses;
@@ -530,7 +607,7 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
"Parallel insert slice is not a valid clone destination");
unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
- LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
+ LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(use->getOwner());
@@ -607,8 +684,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
Operation *tiled =
tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
if (tiled) {
- LLVM_DEBUG(llvm::dbgs() << "\nFused a direct extract use\n"
- << *containingOp);
+ LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
fusedOps.push_back(tiled);
continue;
}
@@ -617,9 +693,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
rewriter, diag, producerOp, containingOp);
if (tiledContainingOpOperand) {
- LLVM_DEBUG(llvm::dbgs()
- << "\nFused an extract use through block argument\n"
- << *containingOp);
+ LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
+ << *containingOp);
fusedOps.push_back(tiledContainingOpOperand);
continue;
}
@@ -627,8 +702,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
Operation *cloned =
cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
if (cloned) {
- LLVM_DEBUG(llvm::dbgs() << "\nFused an use by cloning\n"
- << *containingOp);
+ LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp);
fusedOps.push_back(cloned);
continue;
}
@@ -697,6 +771,123 @@ LogicalResult transform::InterchangeOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// LowerPackOp
+//===----------------------------------------------------------------------===//
+
+struct LowerPackResult { // Ops created by rewriteLowerPack, in creation order.
+ tensor::PadOp padOp; // Pads the pack source to the collapsed shape.
+ tensor::ExpandShapeOp expandShapeOp; // Expands to the strip-mined shape.
+ linalg::TransposeOp transposeOp; // Transposes into the pack destination.
+};
+
+/// Rewrite `packOp` as tensor.pad + tensor.expand_shape + linalg.transpose.
+static FailureOr<LowerPackResult> rewriteLowerPack(RewriterBase &rewriter,
+ tensor::PackOp packOp) {
+ // 1. Filter out NYI cases: outer_dims_perm and non-static result shapes.
+ if (!packOp.getOuterDimsPerm().empty())
+ return rewriter.notifyMatchFailure(packOp, "outer dims perm NYI");
+
+ auto packedTensorType =
+ packOp->getResultTypes().front().cast<RankedTensorType>();
+ if (!packedTensorType.hasStaticShape()) {
+ return rewriter.notifyMatchFailure(
+ packOp,
+ "non-static shape NYI, needs a more powerful tensor.expand_shape op");
+ }
+
+ Location loc = packOp->getLoc();
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(packOp);
+
+ // 2. Compute the permutation vector to move the last `numPackedDims` dims
+ // into their insert positions in a shape of rank `packedRank`.
+ int64_t numPackedDims = packOp.getInnerDimsPos().size();
+ int64_t packedRank = packedTensorType.getRank();
+ auto lastDims = llvm::to_vector(
+ llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+ PackingMetadata packingMetadata = computePackingMetadata(
+ packedTensorType.getRank(), packOp.getInnerDimsPos());
+ SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
+ packedRank, lastDims, packingMetadata.insertPositions);
+
+ // 3. Compute the stripMinedShape: this is the packed shape before any outer
+ // or inner permutations have been applied.
+ SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
+ applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
+
+ // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
+ RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
+ RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
+ packingMetadata.reassociations);
+ Value paddingValue = packOp.getPaddingValue();
+ if (!paddingValue) { // No padding value on the op: pad with a zero constant.
+ paddingValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
+ }
+ auto padOp =
+ tensor::createPadHighOp(collapsed, packOp.getSource(), paddingValue,
+ /*nofold=*/false, loc, rewriter);
+
+ LLVM_DEBUG(
+ DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
+ DBGS() << "insertPositions: ");
+ DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
+ DBGS() << "packedShape: ");
+ DBGSNL();
+ llvm::interleaveComma(lastDimsToInsertPositionsPerm,
+ DBGS() << "lastDimsToInsertPositionsPerm: ");
+ DBGSNL(); llvm::interleaveComma(
+ packingMetadata.reassociations, DBGS() << "reassociations: ",
+ [&](ReassociationIndices ri) {
+ llvm::interleaveComma(ri, llvm::dbgs() << "|");
+ });
+ DBGSNL();
+ llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
+ DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
+
+ // 5. Expand from the padded result to the stripMinedShape.
+ auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
+ loc,
+ RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
+ padOp.getResult(), packingMetadata.reassociations);
+
+ // 6. Transpose stripMinedShape to packedShape, writing into the pack dest.
+ SmallVector<int64_t> insertPositionsToLastDimsPerm = computePermutationVector(
+ packedRank, packingMetadata.insertPositions, lastDims);
+ auto transposeOp = rewriter.create<linalg::TransposeOp>(
+ loc, reshapeOp.getResult(), packOp.getDest(),
+ insertPositionsToLastDimsPerm);
+
+ LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
+ DBGS() << "reshape op: " << reshapeOp; DBGSNL();
+ llvm::interleaveComma(insertPositionsToLastDimsPerm,
+ DBGS() << "insertPositionsToLastDimsPerm: ");
+ DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
+
+ // 7. Replace packOp by the results of transposeOp.
+ rewriter.replaceOp(packOp, transposeOp->getResults());
+
+ return LowerPackResult{padOp, reshapeOp, transposeOp};
+}
+
+DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
+ tensor::PackOp target, transform::ApplyToEachResultList &transformResults,
+ transform::TransformState &state) {
+ IRRewriter rewriter(target->getContext());
+ rewriter.setInsertionPoint(target);
+ FailureOr<LowerPackResult> res = rewriteLowerPack(rewriter, target);
+ if (failed(res)) { // Rewrite did not apply: report a silenceable failure.
+ Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
+ diag << "cannot lower to pad + expand + transpose";
+ return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+ }
+ transformResults.push_back(res->padOp); // Order: pad, expand_shape, transpose.
+ transformResults.push_back(res->expandShapeOp);
+ transformResults.push_back(res->transposeOp);
+ return DiagnosedSilenceableFailure::success();
+}
+
//===---------------------------------------------------------------------===//
// MatchOp
//===---------------------------------------------------------------------===//
@@ -931,12 +1122,12 @@ transform::PackOp::apply(transform::TransformResults &transformResults,
IRRewriter rewriter(linalgOp->getContext());
rewriter.setInsertionPoint(linalgOp);
- FailureOr<LinalgOp> maybeResult = pack(rewriter, linalgOp, packedSizes);
+ FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
if (failed(maybeResult))
return emitDefiniteFailure("data tiling failed");
transformResults.set(getPackedOp().cast<OpResult>(),
- maybeResult->getOperation());
+ maybeResult->packedLinalgOp.getOperation());
return DiagnosedSilenceableFailure::success();
}
@@ -1045,35 +1236,8 @@ static FailureOr<GemmDimsForPacking> getGemmDims(LinalgOp linalgOp) {
return GemmDimsForPacking{*ac.begin(), *bc.begin(), *ra.begin()};
}
-/// Return a permutation vector of size permSize that would result in moving
-/// positions into desiredPositions.
-///
-/// For example, permSize == 5, positions = {2, 4}, desiredPositions = {1, 0}
-/// would result in a {4, 2, 0, 1, 3} permutation vector.
-static SmallVector<int64_t>
-computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
- ArrayRef<int64_t> desiredPositions) {
- SmallVector<int64_t> res(permSize, -1);
- DenseSet<int64_t> seen;
- for (auto [pos, desiredPos] : llvm::zip(positions, desiredPositions)) {
- res[desiredPos] = pos;
- seen.insert(pos);
- }
- int64_t nextPos = 0;
- for (int64_t &entry : res) {
- if (entry != -1)
- continue;
- while (seen.contains(nextPos))
- ++nextPos;
- entry = nextPos;
- ++nextPos;
- }
- return res;
-}
-
-/// Pack a LinalgOp by greedily inferring gemm dimensions (m, n, k)
-/// where m and n are proper parallel dimensions and k is a proper reduction
-/// dimension.
+/// Pack a LinalgOp by greedily inferring gemm dimensions (m, n, k) where m and
+/// n are proper parallel dimensions and k is a proper reduction dimension.
/// Packing occurs by rewriting the op as a linalg.generic and calling
/// linalg::pack by `mnkPackedSizes`.
/// The order of the packed dimensions is customizable: the `mnkOrder` is a
@@ -1081,7 +1245,7 @@ computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
/// forms.
/// The outer dimensions of the operands are not permuted at this time, this is
/// left for future work.
-static FailureOr<LinalgOp>
+static FailureOr<PackResult>
packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
ArrayRef<OpFoldResult> mnkPackedSizes,
ArrayRef<int64_t> mnkOrder) {
@@ -1182,13 +1346,13 @@ PackGreedilyOp::apply(transform::TransformResults &transformResults,
rewriter.setInsertionPointAfter(linalgOp);
// Failing to pack greedily is perfectly fine.
// In the future we will want to order packings according to some metric.
- FailureOr<LinalgOp> gemm = packGemmGreedily(
+ FailureOr<PackResult> packResult = packGemmGreedily(
/*rewriter=*/rewriter,
/*linalgOp=*/linalgOp,
/*mnkPackedSizes=*/getMixedGemmPackedSizes(),
/*mnkOrder=*/getGemmInnerDimsOrder());
- if (succeeded(gemm)) {
- results.push_back(*gemm);
+ if (succeeded(packResult)) {
+ results.push_back(packResult->packedLinalgOp);
continue;
}
results.push_back(linalgOp);
@@ -1235,9 +1399,9 @@ namespace {
enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
} // namespace
-/// Return true if `permutation` is a valid permutation of the `outer_dims_perm`
-/// (case OuterOrInnerPerm::Outer) or `inner_dims_pos` (OuterOrInnerPerm::Inner)
-/// of the `tensor.pack` or `tensor.unpack` `op.
+/// Return true if `permutation` is a valid permutation of the
+/// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
+/// (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op`.
/// This is the case when the `permutation` rank matches the rank expected by
/// `op` and `permutation` is itself a permutation vector.
/// Return true if either `op` or `permutation` are empty to allow a simpler
@@ -1281,10 +1445,10 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
// Step 2. Bunch of runtime sanity check and error messages.
// Step 2.1. Fail on multi-op handles.
if (packOrUnpackOps.size() != 1 || linalgOps.size() != 1) {
- return emitSilenceableError()
- << "requires target to map to exactly 1 packing op and 1 packed op ("
- << "got " << packOrUnpackOps.size() << " and " << linalgOps.size()
- << ")";
+ return emitSilenceableError() << "requires target to map to exactly 1 "
+ "packing op and 1 packed op ("
+ << "got " << packOrUnpackOps.size() << " and "
+ << linalgOps.size() << ")";
}
// Step 2.2. Fail on wrong type.
@@ -1311,7 +1475,8 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
return emitSilenceableError() << errorMsg;
}
- // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical PackOp.
+ // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
+ // PackOp.
if (unPackOp) {
assert(!packOp && "packOp must be null on entry when unPackOp is not null");
OpOperand *packUse = linalgOp.getDpsInitOperand(
@@ -1700,9 +1865,9 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
}
if (second.size() != first.size() && !second.empty()) {
- auto diag =
- emitSilenceableError()
- << "splitting does not produce the second part for a subset of targets";
+ auto diag = emitSilenceableError()
+ << "splitting does not produce the second part for a subset "
+ "of targets";
diag.attachNote() << "expected splitting to produce the second part of all "
"or none of the targets";
diag.attachNote(noSecondPart->getLoc())
@@ -1965,7 +2130,8 @@ void transform::TileOp::build(OpBuilder &builder, OperationState &result,
Value target,
ArrayRef<OpFoldResult> mixedTileSizes,
ArrayRef<int64_t> interchange) {
- // Loop types are automaticaly splat by the callee, setting up one is enough.
+ // Loop types are automatically splat by the callee, setting up one is
+ // enough.
SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
build(builder, result, loopTypes, target, mixedTileSizes, interchange);
}
@@ -1978,8 +2144,8 @@ void transform::TileOp::build(OpBuilder &builder, OperationState &result,
SmallVector<Value> dynamicTileSizes;
dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
// Call the default builder which sets up the proper operands segment sizes
- // attributes for multiple variadic operands. In the absence of this, horrible
- // bugs ensue.
+ // attributes for multiple variadic operands. In the absence of this,
+ // horrible bugs ensue.
auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
unsigned numExpectedLoops =
staticTileSizes.size() - llvm::count(staticTileSizes, 0);
@@ -2247,8 +2413,8 @@ void transform::TileToForeachThreadOp::build(
SmallVector<Value> dynamicTileSizes;
dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
// Call the default builder which sets up the proper operands segment sizes
- // attributes for multiple variadic operands. In the absence of this, horrible
- // bugs ensue.
+ // attributes for multiple variadic operands. In the absence of this,
+ // horrible bugs ensue.
MLIRContext *ctx = builder.getContext();
auto operationType = pdl::OperationType::get(ctx);
auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
@@ -2284,8 +2450,8 @@ void transform::TileToForeachThreadOp::build(
dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
staticNumThreads);
// Call the default builder which sets up the proper operands segment sizes
- // attributes for multiple variadic operands. In the absence of this, horrible
- // bugs ensue.
+ // attributes for multiple variadic operands. In the absence of this,
+ // horrible bugs ensue.
MLIRContext *ctx = builder.getContext();
auto operationType = pdl::OperationType::get(ctx);
auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
@@ -2415,8 +2581,8 @@ LogicalResult TileToForeachThreadOp::verify() {
return emitOpError(
"tile_sizes and packed_tile_sizes are mutually exclusive");
if (numThreadsSpec == 0 && tileSizesSpec == 0)
- return emitOpError(
- "either (packed_)num_threads or (packed_)tile_sizes must be specified");
+ return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
+ "must be specified");
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 722bd9f2d8a36..f1f92d329ebb4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1052,9 +1052,9 @@ PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
/// Implement packing of a single LinalgOp by performing packing by
/// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator.
/// Return the packed Linalg op on success, failure otherwise.
-FailureOr<linalg::LinalgOp> linalg::pack(RewriterBase &rewriter,
- linalg::LinalgOp linalgOp,
- ArrayRef<OpFoldResult> packedSizes) {
+FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
+ linalg::LinalgOp linalgOp,
+ ArrayRef<OpFoldResult> packedSizes) {
if (packedSizes.size() != linalgOp.getNumLoops()) {
return rewriter.notifyMatchFailure(linalgOp,
"incorrect number of pack sizes");
@@ -1069,6 +1069,8 @@ FailureOr<linalg::LinalgOp> linalg::pack(RewriterBase &rewriter,
llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
DBGSNL(););
+ SmallVector<tensor::PackOp> packOps;
+ SmallVector<tensor::UnPackOp> unPackOps;
// Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
PackedOperandsDimList listOfPackedOperandsDim;
for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
@@ -1124,8 +1126,9 @@ FailureOr<linalg::LinalgOp> linalg::pack(RewriterBase &rewriter,
Attribute zeroAttr =
rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
- inputsAndInits.push_back(rewriter.create<tensor::PackOp>(
+ packOps.push_back(rewriter.create<tensor::PackOp>(
loc, operand, dest, innerPos, innerPackSizes, zero));
+ inputsAndInits.push_back(packOps.back());
}
}
@@ -1149,16 +1152,19 @@ FailureOr<linalg::LinalgOp> linalg::pack(RewriterBase &rewriter,
continue;
}
// Build the symmetrical UnPackOp to the existing PackOp.
- results.push_back(rewriter.create<tensor::UnPackOp>(
+ unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
+ results.push_back(unPackOps.back());
}
// Step 5. Replace `linalgOp`.
rewriter.replaceOp(linalgOp, results);
// Return packedLinalgOp.
- return cast<linalg::LinalgOp>(packedLinalgOp.getOperation());
+ return PackResult{packOps,
+ cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
+ unPackOps};
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4515d711f72bf..7960b64fd7151 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1272,11 +1272,18 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
getReassociationIndices());
}
+RankedTensorType CollapseShapeOp::inferCollapsedType(
+ RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
+ return inferCollapsedType(
+ type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+ type.getContext(), reassociation)));
+}
+
/// Compute the RankedTensorType obtained by applying `reassociation` to
/// `type`.
-static RankedTensorType
-computeTensorReshapeCollapsedType(RankedTensorType type,
- ArrayRef<AffineMap> reassociation) {
+RankedTensorType
+CollapseShapeOp::inferCollapsedType(RankedTensorType type,
+ ArrayRef<AffineMap> reassociation) {
auto shape = type.getShape();
SmallVector<int64_t, 4> newShape;
newShape.reserve(reassociation.size());
@@ -1304,7 +1311,7 @@ computeTensorReshapeCollapsedType(RankedTensorType type,
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
- auto resultType = computeTensorReshapeCollapsedType(
+ auto resultType = inferCollapsedType(
src.getType().cast<RankedTensorType>(),
getSymbolLessAffineMaps(
convertReassociationIndicesToExprs(b.getContext(), reassociation)));
@@ -1336,7 +1343,7 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
auto maps = op.getReassociationMaps();
RankedTensorType expectedType =
- computeTensorReshapeCollapsedType(expandedType, maps);
+ CollapseShapeOp::inferCollapsedType(expandedType, maps);
if (!isSameTypesWithoutEncoding(collapsedType, expectedType))
return op.emitOpError("expected collapsed type to be ")
<< expectedType << ", but got " << collapsedType;
@@ -1436,7 +1443,7 @@ struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
RankedTensorType srcType =
castOp.getSource().getType().cast<RankedTensorType>();
- RankedTensorType newResultType = computeTensorReshapeCollapsedType(
+ RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
srcType, collapseShapeOp.getReassociationMaps());
if (newResultType == collapseShapeOp.getResultType()) {
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
new file mode 100644
index 0000000000000..1b87903bba59f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter | FileCheck %s
+
+func.func @pack(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
+ %cst_0 = arith.constant 0.0 : f32
+
+ // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+ // CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
+ // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3], [4], [5]]
+ // CHECK-SAME: : tensor<136x64x16x16xf32> into tensor<17x8x2x32x16x16xf32>
+ // CHECK: linalg.transpose
+ // CHECK-SAME: ins(%{{.*}} : tensor<17x8x2x32x16x16xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<17x2x16x16x32x8xf32>)
+ // CHECK-SAME: permutation = [0, 2, 4, 5, 3, 1]
+ %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
+ : tensor<129x47x16x16xf32> -> tensor<17x2x16x16x32x8xf32>
+ return %pack : tensor<17x2x16x16x32x8xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+ : (!pdl.operation) -> !transform.op<"tensor.pack">
+ transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 7f4d34a4b916b..39649d7d45d30 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8341,12 +8341,14 @@ cc_library(
":LinalgDialect",
":LinalgTransformOpsIncGen",
":LinalgTransforms",
+ ":LinalgUtils",
":PDLDialect",
":Parser",
":SCFTransforms",
":SideEffectInterfaces",
":Support",
":TensorDialect",
+ ":TensorUtils",
":TilingInterface",
":TransformDialect",
":TransformDialectUtils",
More information about the Mlir-commits
mailing list