[Mlir-commits] [mlir] 7cedd95 - [mlir][Linalg] Add a transform.structured.lower_unpack op
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Feb 1 02:33:31 PST 2023
Author: Nicolas Vasilache
Date: 2023-02-01T02:26:44-08:00
New Revision: 7cedd956d0602bf6aca24c1fa5eda2bd585750ac
URL: https://github.com/llvm/llvm-project/commit/7cedd956d0602bf6aca24c1fa5eda2bd585750ac
DIFF: https://github.com/llvm/llvm-project/commit/7cedd956d0602bf6aca24c1fa5eda2bd585750ac.diff
LOG: [mlir][Linalg] Add a transform.structured.lower_unpack op
This revision introduces `transform.structured.lower_unpack`, which allows
rewriting a `tensor.unpack` to `tensor.empty` + `linalg.transpose` +
`tensor.collapse_shape` + `tensor.extract_slice`.
The implementation is currently limited to static unpack ops that do not have
outer_dims permutations.
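
For example, on the static case exercised by the test added below (SSA value
names are illustrative; types and attributes are those of the test):

  %0 = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8]
      into %arg1 : tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>

is rewritten to:

  %empty = tensor.empty() : tensor<17x8x2x32x16x16xf32>
  %transposed = linalg.transpose
      ins(%arg0 : tensor<17x2x16x16x32x8xf32>)
      outs(%empty : tensor<17x8x2x32x16x16xf32>)
      permutation = [0, 5, 1, 4, 2, 3]
  %collapsed = tensor.collapse_shape %transposed [[0, 1], [2, 3], [4], [5]]
      : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
  %0 = tensor.extract_slice %collapsed[0, 0, 0, 0] [129, 47, 16, 16]
      [1, 1, 1, 1] : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>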
Differential Revision: https://reviews.llvm.org/D142889
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-lower-pack.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 5bd5c034e1c5..8ac2c2e0ad98 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -25,6 +25,7 @@ class LinalgOp;
namespace tensor {
class PackOp;
+class UnPackOp;
} // namespace tensor
namespace transform {
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 2aee320b8193..630000c57f8f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -229,7 +229,7 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
#### Return modes
This operation ignores non-pack ops and drops them in the return.
- This operation produces a silenceableFailure if the padding fails for any
+ This operation produces a silenceableFailure if the rewrite fails for any
reason.
If all the operations referred to by the `target` are rewritten, the
transform succeeds.
@@ -252,6 +252,45 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
}];
}
+//===----------------------------------------------------------------------===//
+// LowerUnPackOp
+//===----------------------------------------------------------------------===//
+def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ TransformEachOpTrait,
+ TransformOpInterface]> {
+ let description = [{
+ Lower a tensor.unpack into empty + linalg.transpose + tensor.collapse_shape +
+ tensor.extract_slice.
+
+ #### Return modes
+
+ This operation ignores non-unpack ops and drops them in the return.
+ This operation produces a silenceableFailure if the rewrite fails for any
+ reason.
+ If all the operations referred to by the `target` are rewritten, the
+ transform succeeds.
+ Return handles to the newly produced empty, transpose, collapse_shape and
+ extract_slice ops.
+ }];
+
+ let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target);
+ let results = (outs Transform_ConcreteOpType<"tensor.empty">:$empty_op,
+ Transform_ConcreteOpType<"linalg.transpose">:$transpose_op,
+ Transform_ConcreteOpType<"tensor.collapse_shape">:$collapse_shape_op,
+ Transform_ConcreteOpType<"tensor.extract_slice">:$extract_slice_op);
+ let assemblyFormat = [{
+ $target attr-dict `:` functional-type(operands, results)
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::tensor::UnPackOp target,
+ ::mlir::transform::ApplyToEachResultList &transformResults,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// MatchOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9952bb1cc5ac..022c94e39e23 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -782,8 +782,8 @@ struct LowerPackResult {
};
/// Rewrite pack as pad + reshape + transpose.
-static FailureOr<LowerPackResult> rewriteLowerPack(RewriterBase &rewriter,
- tensor::PackOp packOp) {
+static FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
+ tensor::PackOp packOp) {
// 1. Filter out NYI cases.
if (!packOp.getOuterDimsPerm().empty())
return rewriter.notifyMatchFailure(packOp, "outer dims perm NYI");
@@ -822,7 +822,7 @@ static FailureOr<LowerPackResult> rewriteLowerPack(RewriterBase &rewriter,
packingMetadata.reassociations);
Value paddingValue = packOp.getPaddingValue();
if (!paddingValue) {
- paddingValue = rewriter.create<arith::ConstantOp>(
+ paddingValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
}
auto padOp =
@@ -876,7 +876,7 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
transform::TransformState &state) {
IRRewriter rewriter(target->getContext());
rewriter.setInsertionPoint(target);
- FailureOr<LowerPackResult> res = rewriteLowerPack(rewriter, target);
+ FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
if (failed(res)) {
Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
diag << "cannot lower to pad + expand + transpose";
@@ -888,6 +888,117 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// LowerUnPackOp
+//===----------------------------------------------------------------------===//
+
+struct LowerUnPackOpResult {
+ tensor::EmptyOp emptyOp;
+ linalg::TransposeOp transposeOp;
+ tensor::CollapseShapeOp collapseShapeOp;
+ tensor::ExtractSliceOp extractSliceOp;
+};
+
+/// Rewrite unpack as empty + transpose + reshape + extract_slice.
+static FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
+ tensor::UnPackOp unPackOp) {
+ // 1. Filter out NYI cases.
+ if (!unPackOp.getOuterDimsPerm().empty())
+ return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
+
+ RankedTensorType packedTensorType = unPackOp.getSourceType();
+ if (!packedTensorType.hasStaticShape()) {
+ return rewriter.notifyMatchFailure(
+ unPackOp,
+ "non-static shape NYI, needs a more powerful tensor.expand_shape op");
+ }
+
+ Location loc = unPackOp->getLoc();
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(unPackOp);
+
+ // 2. Compute the permutation vector to move the last `numPackedDims` into the
+ // `innerPosDims` of a shape of rank `packedRank`.
+ int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
+ int64_t packedRank = packedTensorType.getRank();
+ auto lastDims = llvm::to_vector(
+ llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+ PackingMetadata packingMetadata =
+ computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
+ SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
+ packedRank, lastDims, packingMetadata.insertPositions);
+
+ // 3. Compute the stripMinedShape: this is the packed shape without outer and
+ // inner permutations.
+ SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
+ applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
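+ // For instance, on the test case added below (source
+ // tensor<17x2x16x16x32x8xf32>, inner_dims_pos = [1, 0], inner_tiles =
+ // [32, 8]), insertPositions is [3, 1], lastDimsToInsertPositionsPerm is
+ // [0, 5, 1, 4, 2, 3] and stripMinedShape is [17, 8, 2, 32, 16, 16].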
+
+ // 4. Transpose packedShape to stripMinedShape.
+ RankedTensorType stripMinedTensorType =
+ RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
+ RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+ stripMinedTensorType, packingMetadata.reassociations);
+ auto emptyOp =
+ rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType, ValueRange{});
+ auto transposeOp = rewriter.create<linalg::TransposeOp>(
+ loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
+
+ LLVM_DEBUG(
+ DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
+ DBGS() << "insertPositions: ");
+ DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
+ DBGS() << "packedShape: ");
+ DBGSNL();
+ llvm::interleaveComma(lastDimsToInsertPositionsPerm,
+ DBGS() << "lastDimsToInsertPositionsPerm: ");
+ DBGSNL(); llvm::interleaveComma(
+ packingMetadata.reassociations, DBGS() << "reassociations: ",
+ [&](ReassociationIndices ri) {
+ llvm::interleaveComma(ri, llvm::dbgs() << "|");
+ });
+ DBGSNL();
+ llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
+ DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
+
+ // 5. Collapse from the stripMinedShape to the padded result.
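+ // (On the test case below, the reassociations are [[0, 1], [2, 3], [4], [5]],
+ // collapsing tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>.)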
+ auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
+ loc, collapsedType, transposeOp->getResult(0),
+ packingMetadata.reassociations);
+
+ // 6. ExtractSlice
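+ // (On the test case below, this extracts the tensor<129x47x16x16xf32> result
+ // from the collapsed tensor<136x64x16x16xf32> at zero offsets, unit strides.)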
+ auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
+ int64_t destRank = destTensorType.getRank();
+ OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
+ auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+ loc, destTensorType, reshapeOp->getResult(0),
+ SmallVector<OpFoldResult>(destRank, zero),
+ tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
+ SmallVector<OpFoldResult>(destRank, one));
+
+ // 7. Replace unPackOp by extractSliceOp.
+ rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
+
+ return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
+}
+
+DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
+ tensor::UnPackOp target, transform::ApplyToEachResultList &transformResults,
+ transform::TransformState &state) {
+ IRRewriter rewriter(target->getContext());
+ rewriter.setInsertionPoint(target);
+ FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
+ if (failed(res)) {
+ Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
+ diag << "cannot rewrite to pad + expand + transpose";
+ return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+ }
+ transformResults.push_back(res->emptyOp);
+ transformResults.push_back(res->transposeOp);
+ transformResults.push_back(res->collapseShapeOp);
+ transformResults.push_back(res->extractSliceOp);
+ return DiagnosedSilenceableFailure::success();
+}
+
//===---------------------------------------------------------------------===//
// MatchOp
//===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 1b87903bba59..7a89cf90214c 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -1,5 +1,6 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter | FileCheck %s
+// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file | FileCheck %s
+ // CHECK-LABEL: func.func @pack(
func.func @pack(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
%cst_0 = arith.constant 0.0 : f32
@@ -26,3 +27,33 @@ transform.sequence failures(propagate) {
-> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
}
+// -----
+
+// CHECK-LABEL: func.func @unpack(
+func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+ %cst_0 = arith.constant 0.0 : f32
+
+ // CHECK: tensor.empty() : tensor<17x8x2x32x16x16xf32>
+ // CHECK: linalg.transpose
+ // CHECK-SAME: ins(%{{.*}} : tensor<17x2x16x16x32x8xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<17x8x2x32x16x16xf32>)
+ // CHECK-SAME: permutation = [0, 5, 1, 4, 2, 3]
+ // CHECK: tensor.collapse_shape {{.*}}[0, 1], [2, 3], [4], [5]]
+ // CHECK-SAME: : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
+ // CHECK-SAME: : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
+ %pack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
+ : tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
+ return %pack : tensor<129x47x16x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+ : (!pdl.operation) -> !transform.op<"tensor.unpack">
+ transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+ -> (!transform.op<"tensor.empty">,
+ !transform.op<"linalg.transpose">,
+ !transform.op<"tensor.collapse_shape">,
+ !transform.op<"tensor.extract_slice">)
+}