[Mlir-commits] [mlir] 1d1a331 - [mlir][Linalg] NFC - Expose packing transpose implementation as a standalone functional-style API call
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Jan 24 07:47:46 PST 2023
Author: Nicolas Vasilache
Date: 2023-01-24T07:47:40-08:00
New Revision: 1d1a3313513f5e15759328d27e4c2350977140c4
URL: https://github.com/llvm/llvm-project/commit/1d1a3313513f5e15759328d27e4c2350977140c4
DIFF: https://github.com/llvm/llvm-project/commit/1d1a3313513f5e15759328d27e4c2350977140c4.diff
LOG: [mlir][Linalg] NFC - Expose packing transpose implementation as a standalone functional-style API call
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 07f56770b63c8..6cb73bacb0e06 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -21,6 +21,7 @@
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"
@@ -1141,12 +1142,32 @@ FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter);
-/// Implement packing of a single LinalgOp by performing packing by
-/// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator.
+/// Implement packing of a single LinalgOp by `packedSizes`.
+/// There must be one packedSizes entry per `linalgOp` iterator.
/// Return the packed Linalg op on success, failure otherwise.
FailureOr<linalg::LinalgOp> pack(RewriterBase &rewriter,
linalg::LinalgOp linalgOp,
ArrayRef<OpFoldResult> packedSizes);
+
+/// Struct to hold the result of a `packTranspose` call.
+struct PackTransposeResult {
+ tensor::PackOp transposedPackOp;
+ linalg::LinalgOp transposedLinalgOp;
+ tensor::UnPackOp transposedUnPackOp;
+};
+/// Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the
+/// transposed PackOp -> LinalgOp -> UnPackOp chain after replacements.
+/// Return failure if any of the following preconditions is violated:
+/// 1. the `packOp` must have the `linalgOp` as its unique use.
+/// 2. the `maybeUnPackOp`, if specified, must be a consumer of the result
+/// tied to the unique `packOp` use.
+/// 3. `outerPerm` (resp. `innerPerm`) must be a valid permutation of
+/// `packOp.getOuterDimsPerm` (resp. `packOp.getInnerDimsPerm`) or empty.
+FailureOr<PackTransposeResult>
+packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
+ linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
+ ArrayRef<int64_t> outerPerm, ArrayRef<int64_t> innerPerm);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4e0be5aa8dbd4..554328c6cbc1c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -950,6 +950,23 @@ void transform::PackOp::getEffects(
// PackTransposeOp
//===---------------------------------------------------------------------===//
+LogicalResult transform::PackTransposeOp::verify() {
+ if (!isPermutationVector(getInnerPerm())) {
+ return emitOpError() << getInnerPermAttrName()
+ << " is not a valid permutation";
+ }
+ if (!isPermutationVector(getOuterPerm())) {
+ return emitOpError() << getOuterPermAttrName()
+ << " is not a valid permutation";
+ }
+ if (getInnerPerm().empty() && getOuterPerm().empty()) {
+ return emitOpError() << " at least one of " << getInnerPermAttrName()
+ << " or " << getOuterPermAttrName()
+ << " must be specified";
+ }
+ return success();
+}
+
namespace {
enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
} // namespace
@@ -983,84 +1000,6 @@ bool isValidPackingPermutation(
isPermutationVector(permutation);
}
-/// Return a copy of `tensorType` after permutation by `permutationVector`.
-// Note: Should be a new method in of MemRef/RankedTensor/VectorType::Builder
-// but this would introduce a dependence on Dialect in IR.
-// TODO: Restructure.
-static RankedTensorType permuteShape(RankedTensorType tensorType,
- ArrayRef<int64_t> permutationVector) {
- SmallVector<int64_t> shape(tensorType.getShape());
- applyPermutationToVector(shape, permutationVector);
- return RankedTensorType::Builder(tensorType).setShape(shape);
-}
-
-/// Return a new GenericOp obtained by transposing opOperand by the permutation
-/// vector:
-/// - the corresponding indexing map is transposed by `permutation`
-/// - the corresponding operand value is replaced by `transposedValue`
-/// `linalgOp` is replaced by the return op in the process.
-/// Asserts that `transposedValue` is of the proper transposed ShapedType.
-static LinalgOp transposeOneLinalgOperandAndReplace(
- RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
- ArrayRef<int64_t> permutation, Value transposedValue) {
- // Sanity check the operand.
- assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
-
- // Sanity check of the expected transposed tensor type.
- auto tensorType = permuteShape(
- opOperand.get().getType().cast<RankedTensorType>(), permutation);
- (void)tensorType;
- assert(tensorType == transposedValue.getType() &&
- "expected tensor type mismatch");
-
- // Compute the transposed indexing map.
- // Sigh unsigned pollution.
- SmallVector<unsigned> tmpTransposition = llvm::to_vector(
- llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
- AffineMap permutationMap =
- AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
- AffineMap transposedMap =
- permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
-
- // Set the transposed indexing map in the proper position.
- SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
- indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
- // Set the transposedValue in the proper operand position.
- SmallVector<Value> operands = linalgOp->getOperands();
- operands[opOperand.getOperandNumber()] = transposedValue;
-
- ValueRange operandsRef(operands);
- auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
- /*location=*/linalgOp->getLoc(),
- /*resultTensorTypes=*/
- operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
- /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
- /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
- /*indexingMaps=*/indexingMaps,
- /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
- transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
- rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
-
- return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
-}
-
-LogicalResult transform::PackTransposeOp::verify() {
- if (!isPermutationVector(getInnerPerm())) {
- return emitOpError() << getInnerPermAttrName()
- << " is not a valid permutation";
- }
- if (!isPermutationVector(getOuterPerm())) {
- return emitOpError() << getOuterPermAttrName()
- << " is not a valid permutation";
- }
- if (getInnerPerm().empty() && getOuterPerm().empty()) {
- return emitOpError() << " at least one of " << getInnerPermAttrName()
- << " or " << getOuterPermAttrName()
- << " must be specified";
- }
- return success();
-}
-
DiagnosedSilenceableFailure
transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
transform::TransformState &state) {
@@ -1138,68 +1077,23 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
assert(packOp && linalgOp && "unexpected null op");
// Step 3. Actually transpose the ops.
- Location loc = linalgOp.getLoc();
IRRewriter rewriter(getContext());
+ FailureOr<PackTransposeResult> res = packTranspose(
+ rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
+ // Preconditions have been checked, it is an error to fail here.
+ assert(succeeded(res) && "unexpected packTranspose failure");
- // Step 3.a. Transpose packOp.
- rewriter.setInsertionPoint(packOp);
- tensor::PackOp transposedPackOp = packOp.createTransposedClone(
- rewriter, loc, getInnerPerm(), getOuterPerm());
-
- // Step 3.b. Transpose linalgOp.
- assert(packOp.getResult().hasOneUse() && "expect single use");
- // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
- // identity. Don't rely on it.
- int64_t numLeadingDims = packOp.getSourceRank();
- int64_t numTrailingDims = packOp.getInnerDimsPos().size();
- // Step 3.b.i. Compute the permutation on the whole operand.
- // Leading part just reuse the outerPerm.
- SmallVector<int64_t> permutation(getOuterPerm());
- if (permutation.empty())
- llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
- // Trailing part needs to reindex positions by `numLeadingDims`.
- if (getInnerPerm().empty()) {
- llvm::append_range(
- permutation,
- llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
- } else {
- llvm::append_range(permutation,
- llvm::map_range(getInnerPerm(), [&](int64_t pos) {
- return numLeadingDims + pos;
- }));
- }
- assert(isPermutationVector(permutation) && "invalid permutation");
- // Step 3.b.ii. Save the transposedPackUse operand number in case we need to
- // get the tied OpResult after `linalgOp` has been replaced.
- OpOperand &packUse = *(packOp.getResult().getUses().begin());
- int64_t packUseOperandNumber = packUse.getOperandNumber();
- // Step 3.b.iii. Actually perform the transposition.
- rewriter.setInsertionPoint(linalgOp);
- linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
- rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
-
- // Step 3.c. Maybe transpose unPackOp.
- tensor::UnPackOp transposedUnPackOp;
- if (unPackOp) {
- OpOperand &opOperand =
- transposedLinalgOp->getOpOperand(packUseOperandNumber);
- OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
- rewriter.setInsertionPoint(unPackOp);
- transposedUnPackOp = unPackOp.createTransposedClone(
- rewriter, loc, transposedResult, getInnerPerm(), getOuterPerm());
- }
-
- // Step 4. Replace and return results.
- rewriter.replaceOp(packOp, transposedPackOp->getResults());
- transformResults.set(getPackOp().cast<OpResult>(), {transposedPackOp});
- // transposedLinalgOp was replaced in `transposeOneLinalgOperandAndReplace`.
- transformResults.set(getPackedOp().cast<OpResult>(), {transposedLinalgOp});
+ // Step 4. Return results.
+ transformResults.set(getPackOp().cast<OpResult>(), {res->transposedPackOp});
+ transformResults.set(getPackedOp().cast<OpResult>(),
+ {res->transposedLinalgOp});
if (unPackOp) {
- rewriter.replaceOp(unPackOp, transposedUnPackOp->getResults());
- transformResults.set(getUnPackOp().cast<OpResult>(), {transposedUnPackOp});
+ transformResults.set(getUnPackOp().cast<OpResult>(),
+ {res->transposedUnPackOp});
} else {
transformResults.set(getUnPackOp().cast<OpResult>(), {});
}
+
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8e8657d64e851..4799b6e4e98c7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1161,3 +1161,147 @@ FailureOr<linalg::LinalgOp> linalg::pack(RewriterBase &rewriter,
// Return packedLinalgOp.
return cast<linalg::LinalgOp>(packedLinalgOp.getOperation());
}
+
+//===----------------------------------------------------------------------===//
+// packTranspose transformation.
+//===----------------------------------------------------------------------===//
+
+/// Return a copy of `tensorType` after permutation by `permutationVector`.
+// Note: Should be a new method of MemRef/RankedTensor/VectorType::Builder
+// but this would introduce a dependence on Dialect in IR.
+// TODO: Restructure.
+static RankedTensorType permuteShape(RankedTensorType tensorType,
+ ArrayRef<int64_t> permutationVector) {
+ SmallVector<int64_t> shape(tensorType.getShape());
+ applyPermutationToVector(shape, permutationVector);
+ return RankedTensorType::Builder(tensorType).setShape(shape);
+}
+
+/// Return a new GenericOp obtained by transposing opOperand by the permutation
+/// vector:
+/// - the corresponding indexing map is transposed by `permutation`
+/// - the corresponding operand value is replaced by `transposedValue`
+/// `linalgOp` is replaced by the return op in the process.
+/// Asserts that `transposedValue` is of the proper transposed ShapedType.
+static LinalgOp transposeOneLinalgOperandAndReplace(
+ RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
+ ArrayRef<int64_t> permutation, Value transposedValue) {
+ // Sanity check the operand.
+ assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
+
+ // Sanity check of the expected transposed tensor type.
+ auto tensorType = permuteShape(
+ opOperand.get().getType().cast<RankedTensorType>(), permutation);
+ (void)tensorType;
+ assert(tensorType == transposedValue.getType() &&
+ "expected tensor type mismatch");
+
+ // Compute the transposed indexing map.
+ // Sigh unsigned pollution.
+ SmallVector<unsigned> tmpTransposition = llvm::to_vector(
+ llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
+ AffineMap permutationMap =
+ AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
+ AffineMap transposedMap =
+ permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
+
+ // Set the transposed indexing map in the proper position.
+ SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+ indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
+ // Set the transposedValue in the proper operand position.
+ SmallVector<Value> operands = linalgOp->getOperands();
+ operands[opOperand.getOperandNumber()] = transposedValue;
+
+ ValueRange operandsRef(operands);
+ auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
+ /*location=*/linalgOp->getLoc(),
+ /*resultTensorTypes=*/
+ operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
+ /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
+ /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
+ /*indexingMaps=*/indexingMaps,
+ /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
+ transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
+ rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
+
+ return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
+}
+
+FailureOr<PackTransposeResult>
+linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
+ linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
+ ArrayRef<int64_t> outerPerm,
+ ArrayRef<int64_t> innerPerm) {
+ Location loc = linalgOp.getLoc();
+
+ // Step 1. Transpose packOp.
+ rewriter.setInsertionPoint(packOp);
+ tensor::PackOp transposedPackOp =
+ packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
+
+ if (!packOp.getResult().hasOneUse())
+ return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");
+
+ OpOperand &packUse = *packOp->getUses().begin();
+ if (packUse.getOwner() != linalgOp) {
+ return rewriter.notifyMatchFailure(
+ linalgOp, "not a single use by the LinalgOp target");
+ }
+ if (maybeUnPackOp &&
+ (!linalgOp.isDpsInit(&packUse) ||
+ maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
+ return rewriter.notifyMatchFailure(linalgOp,
+ "not produced by the LinalgOp target");
+ }
+
+ // Step 2. Transpose linalgOp.
+ // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
+ // identity. Don't rely on it.
+ int64_t numLeadingDims = packOp.getSourceRank();
+ int64_t numTrailingDims = packOp.getInnerDimsPos().size();
+ // Step 2.a. Compute the permutation on the whole operand.
+ // Leading part just reuse the outerPerm.
+ SmallVector<int64_t> permutation(outerPerm);
+ if (permutation.empty())
+ llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
+ // Trailing part needs to reindex positions by `numLeadingDims`.
+ if (innerPerm.empty()) {
+ llvm::append_range(
+ permutation,
+ llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
+ } else {
+ llvm::append_range(permutation,
+ llvm::map_range(innerPerm, [&](int64_t pos) {
+ return numLeadingDims + pos;
+ }));
+ }
+ if (!isPermutationVector(permutation))
+ return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");
+
+ // Step 2.b. Save the `packUse` operand number in case we need to get the
+ // tied OpResult after `linalgOp` has been replaced.
+ int64_t packUseOperandNumber = packUse.getOperandNumber();
+ // Step 2.c. Actually perform the transposition.
+ rewriter.setInsertionPoint(linalgOp);
+ linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
+ rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
+
+ // Step 3. Maybe transpose unPackOp.
+ tensor::UnPackOp transposedUnPackOp;
+ if (maybeUnPackOp) {
+ OpOperand &opOperand =
+ transposedLinalgOp->getOpOperand(packUseOperandNumber);
+ OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
+ rewriter.setInsertionPoint(maybeUnPackOp);
+ transposedUnPackOp = maybeUnPackOp.createTransposedClone(
+ rewriter, loc, transposedResult, innerPerm, outerPerm);
+
+ rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
+ }
+
+ // Step 4. Finally, replace packOp now that we don't need it anymore.
+ rewriter.replaceOp(packOp, transposedPackOp->getResults());
+
+ return PackTransposeResult{transposedPackOp, transposedLinalgOp,
+ transposedUnPackOp};
+}
More information about the Mlir-commits
mailing list