[Mlir-commits] [mlir] 790f237 - [mlir][Linalg] Add a structured.pack_transpose transform op
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Jan 20 00:33:27 PST 2023
Author: Nicolas Vasilache
Date: 2023-01-20T00:30:16-08:00
New Revision: 790f237012259186ed4a767e29e85f5ba6720b59
URL: https://github.com/llvm/llvm-project/commit/790f237012259186ed4a767e29e85f5ba6720b59
DIFF: https://github.com/llvm/llvm-project/commit/790f237012259186ed4a767e29e85f5ba6720b59.diff
LOG: [mlir][Linalg] Add a structured.pack_transpose transform op
This transform is complementary to the `structured.pack` op which
allows packing a whole op but does not allow transposes on the individual
operands.
`structured.pack_transpose` allows transposing single operands connected to
pack or unpack ops after the fact.
This makes the system overall more composable than e.g. a giant transform
op with all permutation specified at once.
Differential Revision: https://reviews.llvm.org/D142053
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Linalg/transform-op-pack.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index a04c48f5fa4bf..9b51228382933 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -773,6 +773,9 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
LogicalResult reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes);
+ /// Return the index in the indexingMaps vector that corresponds to this `opOperand`
+ int64_t getIndexingMapIndex(OpOperand *opOperand);
+
//========================================================================//
// Forwarding functions to access interface methods from the
// DestinationStyleOpInterface.
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f6c601f73df8e..991aa041b9531 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -363,8 +363,11 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
}];
}
+//===----------------------------------------------------------------------===//
+// PackOp
+//===----------------------------------------------------------------------===//
def PackOp : Op<Transform_Dialect, "structured.pack", [
- TransformOpInterface,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,]> {
let description = [{
Pack a LinalgOp by applying a data tiling transformation on the op and
@@ -439,14 +442,73 @@ def PackOp : Op<Transform_Dialect, "structured.pack", [
}];
let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure apply(
- transform::TransformResults &transformResults,
- transform::TransformState &state);
-
::llvm::SmallVector<::mlir::OpFoldResult> getMixedPackedSizes();
}];
}
+//===----------------------------------------------------------------------===//
+// PackTransposeOp
+//===----------------------------------------------------------------------===//
+def PackTransposeOp : Op<Transform_Dialect, "structured.pack_transpose", [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let description = [{
+ Apply a transposition to a single `tensor.pack` (resp. `tensor.unpack`) and
+ update the `linalg.generic` op that consumes (resp. produces) the operation.
+
+ This transform allows composing a simple `structured.pack` with additional
+ transpositions to e.g. match the data format required by a specific library
+ call or ISA instruction.
+
+ The transpose spec must specify at least one of `outer_perm` or `inner_perm`
+ attributes, which will act upon the `outer_dims_perm` or `inner_dims_pos` of
+ the specified `tensor.pack` or `tensor.unpack` op.
+
+ If the `target` of this op is a `tensor.pack` then a new `tensor.empty` will
+ be created along with transposed versions of the `tensor.pack` and the
+ consuming `linalg.generic`, which is expected to be the sole consumer.
+
+ If the `target` of this op is a `tensor.unpack` then the whole pack / compute
+ / unpack chain will be transposed and transposed clones of `tensor.pack`,
+ the consuming `linalg.generic` and the tail `tensor.unpack` will be created.
+
+ #### Return modes
+
+ This operation targets a single `tensor.pack` / `tensor.unpack` op and a
+ single matching `linalg.generic` that consumes / produces the op. Otherwise,
+ it produces a silenceable failure.
+
+ This operation may produce a silenceable failure if the transpose spec is
+ ill-formed (i.e. `outer_perm` or `inner_perm` are not permutations of the
+ proper rank) or if the transposition of all involved operations fails for any
+ reason.
+
+ This operation returns 3 handles, one to the transformed LinalgOp, one to
+ the transformed `tensor.pack` and one to the transformed `tensor.unpack`.
+ The last handle for `tensor.unpack` is empty if `target_pack_or_un_pack_op`
+ was not itself a `tensor.unpack`.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target_pack_or_un_pack_op,
+ TransformHandleTypeInterface:$target_linalg_op,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_perm,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$inner_perm);
+ let results = (outs TransformHandleTypeInterface:$packed_op,
+ TransformHandleTypeInterface:$pack_op,
+ TransformHandleTypeInterface:$un_pack_op);
+ let assemblyFormat = [{
+ $target_pack_or_un_pack_op
+ `with_compute_op` `(` $target_linalg_op `)`
+ (`outer_perm` `=` $outer_perm^ )?
+ (`inner_perm` `=` $inner_perm^ )?
+ attr-dict
+ `:` functional-type(operands, results)
+ }];
+
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index f0f12388186a5..1b7b17749b84a 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1776,6 +1776,21 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
static Value createDestinationTensor(OpBuilder &b, Location loc,
Value source, ArrayRef<OpFoldResult> innerTileSizes,
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
+
+ /// Build and return a new PackOp that is a clone of the current PackOp with
+ /// (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
+ /// innerPermutation (resp. outerPermutation).
+ /// A new `tensor.empty` of the proper shape is built in the process.
+ /// Asserts that:
+ /// - At least one of innerPermutation or outerPermutation is non-empty.
+ /// - If not empty, innerPermutation is a valid permutation of size
+ /// matching innerDimsPos.
+ /// - If not empty, outerPermutation is a valid permutation of size
+ /// matching outerDimsPerm.
+ PackOp createTransposedClone(OpBuilder &b,
+ Location loc,
+ ArrayRef<int64_t> innerPermutation,
+ ArrayRef<int64_t> outerPermutation);
}];
let hasCanonicalizeMethod = 1;
@@ -1832,7 +1847,23 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
];
- let extraClassDeclaration = commonExtraClassDeclaration;
+ let extraClassDeclaration = commonExtraClassDeclaration # [{
+ /// Build and return a new UnPackOp that is a clone of the current UnPackOp
+ /// with (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
+ /// innerPermutation (resp. outerPermutation).
+ /// Asserts that:
+ /// - At least one of innerPermutation or outerPermutation is non-empty.
+ /// - If not empty, innerPermutation is a valid permutation of size
+ /// matching innerDimsPos.
+ /// - If not empty, outerPermutation is a valid permutation of size
+ /// matching outerDimsPerm.
+ UnPackOp createTransposedClone(OpBuilder &b,
+ Location loc,
+ Value transposedSource,
+ ArrayRef<int64_t> innerPermutation,
+ ArrayRef<int64_t> outerPermutation);
+ }];
+
let hasCanonicalizeMethod = 1;
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index b5088717972a3..e5e0bdd255dc5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -621,6 +621,22 @@ LinalgOp::reifyResultShapes(OpBuilder &b,
return success();
}
+/// Return the index in the indexingMaps vector that corresponds to this
+/// `opOperand`.
+int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
+ auto operandNumber = opOperand->getOperandNumber();
+ auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
+ if (!dpsIface.isDpsInput(opOperand))
+ return operandNumber;
+ auto [start, end] = dpsIface.getDpsInitsPositionRange();
+ assert(!dpsIface.isDpsInit(opOperand));
+ // Account for potential inputs that are not DPS and may not appear in
+ // `indexingMaps`.
+ return cast<DestinationStyleOpInterface>(*this->getOperation())
+ .getNumDpsInputs() +
+ operandNumber - start;
+}
+
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
LinalgOp linalgOp = cast<LinalgOp>(op);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 6efb813590678..4247a6c5fda36 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -17,17 +17,21 @@
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/IR/TransformUtils.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
@@ -1161,16 +1165,12 @@ transform::PackOp::apply(transform::TransformResults &transformResults,
// Fail on multi-op handles.
auto linalgOp = dyn_cast<linalg::LinalgOp>(targetOps.front());
if (targetOps.size() != 1 || !linalgOp) {
- // TODO: remove this unnecessary set to empty once crashes are fixed.
- transformResults.set(getPackedOp().cast<OpResult>(), {});
return emitSilenceableError()
<< "requires target to map to exactly 1 LinalgOp (got "
<< targetOps.size() << ")";
}
// Fail on mismatched number of pack sizes.
if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
- // TODO: remove this unnecessary set to empty once crashes are fixed.
- transformResults.set(getPackedOp().cast<OpResult>(), {});
return emitSilenceableError()
<< "requires number of packed sizes match the number of loops ("
<< getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
@@ -1194,6 +1194,263 @@ void transform::PackOp::getEffects(
transform::consumesHandle(getTarget(), effects);
transform::onlyReadsHandle(getPackedSizes(), effects);
transform::producesHandle(getPackedOp(), effects);
+ transform::modifiesPayload(effects);
+}
+
+//===---------------------------------------------------------------------===//
+// PackTransposeOp
+//===---------------------------------------------------------------------===//
+
+namespace {
+enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
+} // namespace
+
+/// Return true if `permutation` is a valid permutation of the `outer_dims_perm`
+/// (case OuterOrInnerPerm::Outer) or `inner_dims_pos` (OuterOrInnerPerm::Inner)
+/// of the `tensor.pack` or `tensor.unpack` `op`.
+/// This is the case when the `permutation` rank matches the rank expected by
+/// `op` and `permutation` is itself a permutation vector.
+/// Return true if either `op` or `permutation` are empty to allow a simpler
+/// polymorphic implementation.
+template <typename RelayoutOpTy>
+bool isValidPackingPermutation(
+ RelayoutOpTy op, ArrayRef<int64_t> permutation,
+ OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
+ static_assert(
+ llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value,
+ "applies to only pack or unpack operations");
+ if (!op || permutation.empty())
+ return true;
+ int64_t innerRank = op.getInnerDimsPos().size();
+ if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
+ return permutation.size() == innerRank && isPermutationVector(permutation);
+ // op.getOuterDimsPerm() may be empty, in which case it is identity.
+ // Don't rely on it.
+ if (std::is_same<RelayoutOpTy, tensor::PackOp>::value) {
+ return permutation.size() == op.getSourceRank() &&
+ isPermutationVector(permutation);
+ }
+ return permutation.size() == op.getDestRank() &&
+ isPermutationVector(permutation);
+}
+
+/// Return a copy of `tensorType` after permutation by `permutationVector`.
+// Note: Should be a new method of MemRef/RankedTensor/VectorType::Builder
+// but this would introduce a dependence on Dialect in IR.
+// TODO: Restructure.
+static RankedTensorType permuteShape(RankedTensorType tensorType,
+ ArrayRef<int64_t> permutationVector) {
+ SmallVector<int64_t> shape(tensorType.getShape());
+ applyPermutationToVector(shape, permutationVector);
+ return RankedTensorType::Builder(tensorType).setShape(shape);
+}
+
+/// Return a new GenericOp obtained by transposing opOperand by the permutation
+/// vector:
+/// - the corresponding indexing map is transposed by `permutation`
+/// - the corresponding operand value is replaced by `transposedValue`
+/// `linalgOp` is replaced by the return op in the process.
+/// Asserts that `transposedValue` is of the proper transposed ShapedType.
+static LinalgOp transposeOneLinalgOperandAndReplace(
+ RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
+ ArrayRef<int64_t> permutation, Value transposedValue) {
+ // Sanity check the operand.
+ assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
+
+ // Sanity check of the expected transposed tensor type.
+ auto tensorType = permuteShape(
+ opOperand.get().getType().cast<RankedTensorType>(), permutation);
+ assert(tensorType == transposedValue.getType() &&
+ "expected tensor type mismatch");
+
+ // Compute the transposed indexing map.
+ // Sigh unsigned pollution.
+ SmallVector<unsigned> tmpTransposition = llvm::to_vector(
+ llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
+ AffineMap permutationMap =
+ AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
+ AffineMap transposedMap =
+ permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
+
+ // Set the transposed indexing map in the proper position.
+ SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+ indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
+ // Set the transposedValue in the proper operand position.
+ SmallVector<Value> operands = linalgOp->getOperands();
+ operands[opOperand.getOperandNumber()] = transposedValue;
+
+ ValueRange operandsRef(operands);
+ auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
+ /*location=*/linalgOp->getLoc(),
+ /*resultTensorTypes=*/
+ operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
+ /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
+ /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
+ /*indexingMaps=*/indexingMaps,
+ /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
+ transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
+ rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
+
+ return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
+}
+
+LogicalResult transform::PackTransposeOp::verify() {
+ if (!isPermutationVector(getInnerPerm())) {
+ return emitOpError() << getInnerPermAttrName()
+ << " is not a valid permutation";
+ }
+ if (!isPermutationVector(getOuterPerm())) {
+ return emitOpError() << getOuterPermAttrName()
+ << " is not a valid permutation";
+ }
+ if (getInnerPerm().empty() && getOuterPerm().empty()) {
+ return emitOpError() << " at least one of " << getInnerPermAttrName()
+ << " or " << getOuterPermAttrName()
+ << " must be specified";
+ }
+ return success();
+}
+
+DiagnosedSilenceableFailure
+transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
+ transform::TransformState &state) {
+ ArrayRef<Operation *> packOrUnpackOps =
+ state.getPayloadOps(getTargetPackOrUnPackOp());
+ ArrayRef<Operation *> linalgOps = state.getPayloadOps(getTargetLinalgOp());
+ // Step 1. If nothing to pack, propagate success.
+ if (packOrUnpackOps.empty()) {
+ transformResults.set(getPackedOp().cast<OpResult>(), {});
+ transformResults.set(getPackOp().cast<OpResult>(), {});
+ transformResults.set(getUnPackOp().cast<OpResult>(), {});
+ return DiagnosedSilenceableFailure::success();
+ }
+
+ // Step 2. Bunch of runtime sanity check and error messages.
+ // Step 2.1. Fail on multi-op handles.
+ if (packOrUnpackOps.size() != 1 || linalgOps.size() != 1) {
+ return emitSilenceableError()
+ << "requires target to map to exactly 1 packing op and 1 packed op ("
+ << "got " << packOrUnpackOps.size() << " and " << linalgOps.size()
+ << ")";
+ }
+
+ // Step 2.2. Fail on wrong type.
+ auto packOp = dyn_cast<tensor::PackOp>(packOrUnpackOps.front());
+ auto unPackOp = dyn_cast<tensor::UnPackOp>(packOrUnpackOps.front());
+ if ((!packOp && !unPackOp)) {
+ return emitSilenceableError() << "requires target to map to a "
+ "tensor.pack or tensor.unpack";
+ }
+ LinalgOp linalgOpTarget = dyn_cast<linalg::LinalgOp>(linalgOps.front());
+ if (!linalgOpTarget)
+ return emitSilenceableError() << "requires a LinalgOp target";
+
+ // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
+ LinalgOp linalgOp;
+ if (packOp && packOp.getResult().hasOneUse())
+ linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
+ else if (unPackOp)
+ linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
+ if (linalgOp != linalgOpTarget) {
+ auto errorMsg =
+ packOp ? StringLiteral{"not a single use by the LinalgOp target"}
+ : StringLiteral{"not produced by the LinalgOp target"};
+ return emitSilenceableError() << errorMsg;
+ }
+
+ // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical PackOp.
+ if (unPackOp) {
+ assert(!packOp && "packOp must be null on entry when unPackOp is not null");
+ OpOperand *packUse = linalgOp.getDpsInitOperand(
+ unPackOp.getSource().cast<OpResult>().getResultNumber());
+ packOp = dyn_cast_or_null<tensor::PackOp>(packUse->get().getDefiningOp());
+ if (!packOp || !packOp.getResult().hasOneUse())
+ return emitSilenceableError() << "could not find matching pack op";
+ }
+
+ // Step 2.5. Fail if any permutation does not validate.
+ for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
+ ArrayRef<int64_t> perm =
+ (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
+ auto errorMsg = (permType == OuterOrInnerPerm::Outer)
+ ? StringLiteral{"invalid outer_perm"}
+ : StringLiteral{"invalid inner_perm"};
+ if (!isValidPackingPermutation(packOp, perm, permType) ||
+ !isValidPackingPermutation(unPackOp, perm, permType)) {
+ Operation *packOrUnpackOp =
+ unPackOp ? unPackOp.getOperation() : packOp.getOperation();
+ return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
+ }
+ }
+
+ // From here on, packOp and linalgOp are always present, unPackOp may or may
+ // not be present.
+ assert(packOp && linalgOp && "unexpected null op");
+
+ // Step 3. Actually transpose the ops.
+ Location loc = linalgOp.getLoc();
+ IRRewriter rewriter(getContext());
+
+ // Step 3.a. Transpose packOp.
+ rewriter.setInsertionPoint(packOp);
+ tensor::PackOp transposedPackOp = packOp.createTransposedClone(
+ rewriter, loc, getInnerPerm(), getOuterPerm());
+
+ // Step 3.b. Transpose linalgOp.
+ assert(packOp.getResult().hasOneUse() && "expect single use");
+ // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
+ // identity. Don't rely on it.
+ int64_t numLeadingDims = packOp.getSourceRank();
+ int64_t numTrailingDims = packOp.getInnerDimsPos().size();
+ // Step 3.b.i. Compute the permutation on the whole operand.
+ // Leading part just reuse the outerPerm.
+ SmallVector<int64_t> permutation(getOuterPerm());
+ if (permutation.empty())
+ llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
+ // Trailing part needs to reindex positions by `numLeadingDims`.
+ if (getInnerPerm().empty()) {
+ llvm::append_range(
+ permutation,
+ llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
+ } else {
+ llvm::append_range(permutation,
+ llvm::map_range(getInnerPerm(), [&](int64_t pos) {
+ return numLeadingDims + pos;
+ }));
+ }
+ assert(isPermutationVector(permutation) && "invalid permutation");
+ // Step 3.b.ii. Save the transposedPackUse operand number in case we need to
+ // get the tied OpResult after `linalgOp` has been replaced.
+ OpOperand &packUse = *(packOp.getResult().getUses().begin());
+ int64_t packUseOperandNumber = packUse.getOperandNumber();
+ // Step 3.b.iii. Actually perform the transposition.
+ rewriter.setInsertionPoint(linalgOp);
+ linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
+ rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
+
+ // Step 3.c. Maybe transpose unPackOp.
+ tensor::UnPackOp transposedUnPackOp;
+ if (unPackOp) {
+ OpOperand &opOperand =
+ transposedLinalgOp->getOpOperand(packUseOperandNumber);
+ OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
+ rewriter.setInsertionPoint(unPackOp);
+ transposedUnPackOp = unPackOp.createTransposedClone(
+ rewriter, loc, transposedResult, getInnerPerm(), getOuterPerm());
+ }
+
+ // Step 4. Replace and return results.
+ rewriter.replaceOp(packOp, transposedPackOp->getResults());
+ transformResults.set(getPackOp().cast<OpResult>(), {transposedPackOp});
+ // transposedLinalgOp was replaced in `transposeOneLinalgOperandAndReplace`.
+ transformResults.set(getPackedOp().cast<OpResult>(), {transposedLinalgOp});
+ if (unPackOp) {
+ rewriter.replaceOp(unPackOp, transposedUnPackOp->getResults());
+ transformResults.set(getUnPackOp().cast<OpResult>(), {transposedUnPackOp});
+ } else {
+ transformResults.set(getUnPackOp().cast<OpResult>(), {});
+ }
+ return DiagnosedSilenceableFailure::success();
}
//===---------------------------------------------------------------------===//
@@ -1359,7 +1616,7 @@ transform::ReplaceOp::apply(TransformResults &transformResults,
if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
target->getNumRegions() > 0)
return emitDefiniteFailure()
- << "expected target that is isloated from above";
+ << "expected target that is isolated from above";
}
// Clone and replace.
@@ -1907,32 +2164,31 @@ transform::TileOp::apply(TransformResults &transformResults,
scf::SCFTilingOptions tilingOptions;
unsigned index = en.index();
if (!tileSizes.empty()) {
- tilingOptions.setTileSizeComputationFunction(
- [&, index](OpBuilder &b, Operation *) {
- SmallVector<Value, 4> sizes;
- sizes.reserve(tileSizes.size());
- unsigned dynamicIdx = 0;
- for (OpFoldResult ofr : getMixedSizes()) {
- if (auto attr = ofr.dyn_cast<Attribute>()) {
- sizes.push_back(b.create<arith::ConstantIndexOp>(
- getLoc(), attr.cast<IntegerAttr>().getInt()));
- continue;
- }
- ArrayRef<Operation *> dynamicSizes =
- dynamicSizeProducers[dynamicIdx];
- ArrayRef<int64_t> params = paramSizes[dynamicIdx];
- ++dynamicIdx;
- assert((dynamicSizes.empty() ^ params.empty()) &&
- "expected either dynamic sizes or parameters");
- if (!params.empty()) {
- sizes.push_back(
- b.create<arith::ConstantIndexOp>(getLoc(), params[index]));
- } else {
- sizes.push_back(dynamicSizes[index]->getResult(0));
- }
- }
- return sizes;
- });
+ tilingOptions.setTileSizeComputationFunction([&, index](OpBuilder &b,
+ Operation *) {
+ SmallVector<Value, 4> sizes;
+ sizes.reserve(tileSizes.size());
+ unsigned dynamicIdx = 0;
+ for (OpFoldResult ofr : getMixedSizes()) {
+ if (auto attr = ofr.dyn_cast<Attribute>()) {
+ sizes.push_back(b.create<arith::ConstantIndexOp>(
+ getLoc(), attr.cast<IntegerAttr>().getInt()));
+ continue;
+ }
+ ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
+ ArrayRef<int64_t> params = paramSizes[dynamicIdx];
+ ++dynamicIdx;
+ assert((dynamicSizes.empty() ^ params.empty()) &&
+ "expected either dynamic sizes or parameters");
+ if (!params.empty()) {
+ sizes.push_back(
+ b.create<arith::ConstantIndexOp>(getLoc(), params[index]));
+ } else {
+ sizes.push_back(dynamicSizes[index]->getResult(0));
+ }
+ }
+ return sizes;
+ });
}
tilingOptions.setInterchange(getInterchange());
@@ -2149,27 +2405,27 @@ DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
// Transform all targets one by one.
for (Operation *target : targets) {
- auto tilableOp = dyn_cast<TilingInterface>(target);
- if (!tilableOp) {
+ auto tileableOp = dyn_cast<TilingInterface>(target);
+ if (!tileableOp) {
DiagnosedSilenceableFailure diag =
transformOp.emitSilenceableError()
<< "only TilingInterface ops are supported";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
- rewriter.setInsertionPoint(tilableOp);
+ rewriter.setInsertionPoint(tileableOp);
FailureOr<linalg::ForeachThreadTilingResult> tilingResult = failure();
if (!mixedNumThreads.empty()) {
- tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp,
+ tilingResult = linalg::tileToForeachThreadOp(rewriter, tileableOp,
mixedNumThreads, mapping);
} else {
tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
- rewriter, tilableOp, mixedTileSizes, mapping);
+ rewriter, tileableOp, mixedTileSizes, mapping);
}
if (failed(tilingResult))
- return transformOp.emitDefaultSilenceableFailure(tilableOp);
- rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults());
+ return transformOp.emitDefaultSilenceableFailure(tileableOp);
+ rewriter.replaceOp(tileableOp, tilingResult->tileOp->getResults());
tileOps.push_back(tilingResult->tileOp);
tiledOps.push_back(tilingResult->tiledOp);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index d8c337d32e36c..4515d711f72bf 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3231,7 +3231,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return true;
}
return shape == constTileSize.value();
-
})) {
return op->emitError("mismatch in inner tile sizes specified and shaped of "
"tiled dimension in the packed type");
@@ -3239,6 +3238,57 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return success();
}
+namespace {
+/// Subset of PackOp/UnPackOp fields used to compute the result of applying
+/// various permutations to the op.
+// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
+// these. These may or may not become true foldings / canonicalizations
+// depending on how aggressive we want to be in automatically folding
+// transposes.
+struct PackOrUnPackTransposeResult {
+ SmallVector<int64_t> innerDimsPos;
+ SmallVector<OpFoldResult> innerTiles;
+ SmallVector<int64_t> outerDimsPerm;
+};
+} // namespace
+
+template <typename OpTy>
+static PackOrUnPackTransposeResult
+commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
+ ArrayRef<int64_t> innerPermutation,
+ ArrayRef<int64_t> outerPermutation) {
+ static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
+ "applies to only pack or unpack operations");
+ assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
+ "some permutation must be non-empty");
+ PackOrUnPackTransposeResult metadata;
+ metadata.innerDimsPos =
+ SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
+ metadata.innerTiles =
+ SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
+ int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
+ ? packOrUnPackOp.getSourceRank()
+ : packOrUnPackOp.getDestRank();
+ metadata.outerDimsPerm =
+ packOrUnPackOp.getOuterDimsPerm().empty()
+ ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
+ : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
+ if (!innerPermutation.empty()) {
+ assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
+ isPermutationVector(innerPermutation) &&
+ "invalid inner permutation");
+ applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
+ applyPermutationToVector(metadata.innerTiles, innerPermutation);
+ }
+ if (!outerPermutation.empty()) {
+ assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
+ isPermutationVector(outerPermutation) &&
+ "invalid outer permutation");
+ applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
+ }
+ return metadata;
+}
+
//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//
@@ -3386,6 +3436,19 @@ Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
}
+PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
+ ArrayRef<int64_t> innerPermutation,
+ ArrayRef<int64_t> outerPermutation) {
+ PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
+ *this, innerPermutation, outerPermutation);
+ Value transposedDest =
+ createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
+ metadata.innerDimsPos, metadata.outerDimsPerm);
+ return b.create<PackOp>(loc, getSource(), transposedDest,
+ metadata.innerDimsPos, metadata.innerTiles,
+ getPaddingValue(), metadata.outerDimsPerm);
+}
+
/// Returns true if the tiles and the tiled dims are constant.
template <typename OpTy>
bool areTilesAndTiledDimsAllConstant(OpTy op) {
@@ -3508,6 +3571,17 @@ void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
builder.getDenseI64ArrayAttr(staticTileSizes));
}
+UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
+ Value transposedSource,
+ ArrayRef<int64_t> innerPermutation,
+ ArrayRef<int64_t> outerPermutation) {
+ PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
+ *this, innerPermutation, outerPermutation);
+ return b.create<UnPackOp>(loc, transposedSource, getDest(),
+ metadata.innerDimsPos, metadata.innerTiles,
+ metadata.outerDimsPerm);
+}
+
/// pack(unpack(x)) -> x
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
PatternRewriter &rewriter) {
diff --git a/mlir/test/Dialect/Linalg/transform-op-pack.mlir b/mlir/test/Dialect/Linalg/transform-op-pack.mlir
index d1304bb2be483..b8c569f0a3c17 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pack.mlir
@@ -49,21 +49,21 @@ transform.sequence failures(propagate) {
iterator_types = ["reduction", "parallel"]
}
-// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2) -> (d1)>
// CHECK-LABEL: @col_reduction_2d_static
// CHECK-SAME: %[[T0:.+]]: tensor<7x3xf16>,
// CHECK-SAME: %[[T1:.+]]: tensor<3xf16>
func.func @col_reduction_2d_static(%t0: tensor<7x3xf16>, %t1: tensor<3xf16>) -> tensor<3xf16> {
- // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x4xf16>
+ // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<3x2x4xf16>
// CHECK: %[[PACKED:.*]] = tensor.pack %[[T0]] padding_value(%{{.*}} : f16)
- // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<7x3xf16> -> tensor<2x3x4xf16>
+ // CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<7x3xf16> -> tensor<3x2x4xf16>
// CHECK-NOT: tensor.pack
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]]]
// CHECK-SAME: iterator_types = ["reduction", "parallel", "reduction"]
- // CHECK-SAME: ins(%{{.*}} : tensor<2x3x4xf16>)
+ // CHECK-SAME: ins(%{{.*}} : tensor<3x2x4xf16>)
// CHECK-SAME: outs(%{{.*}} : tensor<3xf16>)
%2 = linalg.generic #col_reduction_2d_trait ins(%t0 : tensor<7x3xf16>) outs(%t1 : tensor<3xf16>) {
^bb0(%in: f16, %out: f16):
@@ -78,8 +78,15 @@ func.func @col_reduction_2d_static(%t0: tensor<7x3xf16>, %t1: tensor<3xf16>) ->
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
- transform.structured.pack %0 packed_sizes = [4, 0]
+ %1 = transform.structured.pack %0 packed_sizes = [4, 0]
: (!pdl.operation) -> (!transform.op<"linalg.generic">)
+ %pack = transform.get_producer_of_operand %1[0]
+ : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.pack">)
+ %2, %pack_2, %empty_unpack_2 =
+ transform.structured.pack_transpose %pack with_compute_op(%1)
+ outer_perm = [1, 0]
+ : (!transform.op<"tensor.pack">, !transform.op<"linalg.generic">)
+ -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !pdl.operation)
}
// -----
@@ -183,7 +190,7 @@ transform.sequence failures(propagate) {
// K N n k
// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
// M N m n
-// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d4, d3)>
// CHECK-LABEL: @matmul
// CHECK-SAME: %[[A:[0-9a-zA-Z]+]]: tensor<?x?xf32>,
@@ -196,19 +203,19 @@ func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
// CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x2x4xf32>
// CHECK: %[[PACK_B:.*]] = tensor.pack %{{.*}} inner_dims_pos = [1, 0] inner_tiles = [3, 4]
// CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x3x4xf32>
- // CHECK: %[[PACK_C:.*]] = tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 3]
- // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x2x3xf32>
+ // CHECK: %[[PACK_C:.*]] = tensor.pack %{{.*}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [3, 2]
+ // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x3x2xf32>
// CHECK: linalg.generic {indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]], #[[$PACKED_MAP_2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
// CHECK-SAME: ins(%{{.*}} : tensor<?x?x2x4xf32>, tensor<?x?x3x4xf32>)
- // CHECK-SAME: outs(%{{.*}} : tensor<?x?x2x3xf32>)
+ // CHECK-SAME: outs(%{{.*}} : tensor<?x?x3x2xf32>)
%0 = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
outs(%C: tensor<?x?xf32>)
-> tensor<?x?xf32>
- // CHECK: tensor.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 3]
- // CHECK-SAME: : tensor<?x?x2x3xf32> -> tensor<?x?xf32>
+ // CHECK: tensor.unpack %{{.*}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [3, 2]
+ // CHECK-SAME: : tensor<?x?x3x2xf32> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
@@ -218,6 +225,14 @@ transform.sequence failures(propagate) {
// M N K
%1 = transform.structured.pack %0 packed_sizes = [2, 3, 4]
: (!pdl.operation) -> (!transform.op<"linalg.generic">)
+
+ %unpack = transform.get_consumers_of_result %1[0]
+ : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.unpack">)
+ %2, %pack_2, %unpack_2 =
+ transform.structured.pack_transpose %unpack with_compute_op(%1)
+ outer_perm = [1, 0] inner_perm = [1, 0]
+ : (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">)
+ -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">)
}
// -----
@@ -404,3 +419,177 @@ transform.sequence failures(propagate) {
%1 = transform.structured.pack %0 packed_sizes = [2, 3]
: (!pdl.operation) -> (!transform.op<"linalg.generic">)
}
+
+// -----
+
+func.func @no_single_packing_op(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
+ %0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
+ %1 = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %source : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
+ %2 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
+ return
+}
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+ %1 = transform.structured.match ops{["tensor.unpack"]} in %arg1
+ // expected-error @below {{requires target to map to exactly 1 packing op and 1 packed op (got 2 and 1)}}
+ transform.structured.pack_transpose %0 with_compute_op(%1)
+ inner_perm = [0]
+ : (!pdl.operation, !pdl.operation)
+ -> (!pdl.operation, !pdl.operation, !pdl.operation)
+}
+
+// -----
+
+func.func @no_single_pack_unpack(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
+ %0 = arith.constant 0 : index
+ %1 = tensor.empty() : tensor<f32>
+ return
+}
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["arith.constant"]} in %arg1
+ %1 = transform.structured.match ops{["tensor.empty"]} in %arg1
+ // expected-error @below {{requires target to map to a tensor.pack or tensor.unpack}}
+ transform.structured.pack_transpose %0 with_compute_op(%1)
+ inner_perm = [0]
+ : (!pdl.operation, !pdl.operation)
+ -> (!pdl.operation, !pdl.operation, !pdl.operation)
+}
+
+// -----
+
+func.func @no_linalg_target(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
+ %0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
+ %1 = arith.constant 0 : index
+ return
+}
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+ %1 = transform.structured.match ops{["arith.constant"]} in %arg1
+ // expected-error @below {{requires a LinalgOp target}}
+ transform.structured.pack_transpose %0 with_compute_op(%1)
+ inner_perm = [0]
+ : (!pdl.operation, !pdl.operation)
+ -> (!pdl.operation, !pdl.operation, !pdl.operation)
+}
+
+// -----
+
+func.func @no_single_use_by_linalg(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
+ %0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
+ %f0 = arith.constant 0.0 : f32
+ %1 = tensor.empty() : tensor<f32>
+ %2 = linalg.fill ins(%f0: f32) outs(%1 : tensor<f32>) -> tensor<f32>
+ return
+}
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+ %1 = transform.structured.match ops{["linalg.fill"]} in %arg1
+ // expected-error @below {{not a single use by the LinalgOp target}}
+ transform.structured.pack_transpose %0 with_compute_op(%1)
+ inner_perm = [0]
+ : (!pdl.operation, !pdl.operation)
+ -> (!pdl.operation, !pdl.operation, !pdl.operation)
+}
+
+// -----
+
+func.func @not_produced_by_linalg(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
+ %a = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
+ %b = tensor.unpack %a inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %source : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
+ %f0 = arith.constant 0.0 : f32
+ %1 = tensor.empty() : tensor<f32>
+ %2 = linalg.fill ins(%f0: f32) outs(%1 : tensor<f32>) -> tensor<f32>
+ return
+}
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
+ %1 = transform.structured.match ops{["linalg.fill"]} in %arg1
+ // expected-error @below {{not produced by the LinalgOp target}}
+ transform.structured.pack_transpose %0 with_compute_op(%1)
+ inner_perm = [0]
+ : (!pdl.operation, !pdl.operation)
+ -> (!pdl.operation, !pdl.operation, !pdl.operation)
+}
+
+// -----
+
+func.func @no_matching_pack(%source: tensor<16xf32>) {
+ %f0 = arith.constant 0.0 : f32
+ %1 = tensor.empty() : tensor<4x4xf32>
+ %2 = linalg.fill ins(%f0: f32) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
+ %b = tensor.unpack %2 inner_dims_pos = [0] inner_tiles = [4] into %source : tensor<4x4xf32> -> tensor<16xf32>
+ return
+}
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
+ %1 = transform.structured.match ops{["linalg.fill"]} in %arg1
+ // expected-error @below {{could not find matching pack op}}
+ transform.structured.pack_transpose %0 with_compute_op(%1)
+ inner_perm = [0]
+ : (!pdl.operation, !pdl.operation)
+ -> (!pdl.operation, !pdl.operation, !pdl.operation)
+}
+
+// -----
+
+func.func @invalid_outer_perm(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
+ -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>)
+ -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ %1 = transform.structured.pack %0 packed_sizes = [2, 3, 4]
+ : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+
+ %unpack = transform.get_consumers_of_result %1[0]
+ : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.unpack">)
+ %2, %pack_2, %unpack_2 =
+ // expected-error @below {{invalid outer_perm}}
+ transform.structured.pack_transpose %unpack with_compute_op(%1)
+ outer_perm = [1]
+ : (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">)
+ -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">)
+}
+
+// -----
+
+func.func @invalid_inner_perm(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
+ -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>)
+ -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ %1 = transform.structured.pack %0 packed_sizes = [2, 3, 4]
+ : (!pdl.operation) -> (!transform.op<"linalg.generic">)
+
+ %unpack = transform.get_consumers_of_result %1[0]
+ : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.unpack">)
+ %2, %pack_2, %unpack_2 =
+ // expected-error @below {{invalid inner_perm}}
+ transform.structured.pack_transpose %unpack with_compute_op(%1)
+ inner_perm = [1]
+ : (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">)
+ -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">)
+}
More information about the Mlir-commits
mailing list