[Mlir-commits] [mlir] d261aa8 - [mlir] Add TransposeOp to Linalg structured ops.
Oleg Shyshkov
llvmlistbot at llvm.org
Wed Oct 19 03:28:11 PDT 2022
Author: Oleg Shyshkov
Date: 2022-10-19T12:27:52+02:00
New Revision: d261aa88f89332fbe9b1ee688ed5b75ae7414aff
URL: https://github.com/llvm/llvm-project/commit/d261aa88f89332fbe9b1ee688ed5b75ae7414aff
DIFF: https://github.com/llvm/llvm-project/commit/d261aa88f89332fbe9b1ee688ed5b75ae7414aff.diff
LOG: [mlir] Add TransposeOp to Linalg structured ops.
RFC: https://discourse.llvm.org/t/rfc-primitive-ops-add-mapop-reductionop-transposeop-broadcastop-to-linalg/64184
Differential Revision: https://reviews.llvm.org/D135854
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index 28c75fcfa6530..e231bddfcc414 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -70,6 +70,10 @@ AffineMap extractOrIdentityMap(Optional<AffineMap> maybeMap, unsigned rank,
SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
ArrayRef<AffineExpr> b);
+/// Check if `permutation` is a permutation of the range
+/// `[0, permutation.size())`, i.e. every index in that range occurs exactly
+/// once. Any negative or out-of-range entry makes the result false. An empty
+/// array is accepted as the (trivial) empty permutation.
+bool isPermutation(ArrayRef<int64_t> permutation);
} // namespace linalg
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 4b83de12a4108..9c2246e74dcac 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -360,6 +360,78 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
+// Transpose op.
+def TransposeOp : LinalgStructuredBase_Op<"transpose", [
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    SameVariadicOperandSize,
+    SingleBlockImplicitTerminator<"YieldOp">]> {
+  let summary = "Transpose operator";
+  let description = [{
+    Permutes the dimensions of `input` according to the given `permutation`:
+      `dim(result, i) = dim(input, permutation[i])`
+
+    Unlike `memref.transpose`, which only produces a transposed "view" by
+    changing metadata, this op actually moves data.
+
+    With tensor operands the op returns the transposed tensor as its single
+    result. With memref operands it writes into `init` and has no results.
+
+    Example:
+
+    ```
+    %transpose = linalg.transpose
+        ins(%input:tensor<16x64xf32>)
+        outs(%init:tensor<64x16xf32>)
+        permutation = [1, 0]
+    ```
+  }];
+
+  let arguments = (ins
+    // Input arg
+    TensorOrMemref:$input,
+    // Output arg
+    TensorOrMemref:$init,
+    DenseI64ArrayAttr:$permutation
+  );
+  // A single tensor result for tensor operands; no results for memref operands.
+  let results = (outs Variadic<AnyTensor>:$result);
+  let regions = (region SizedRegion<1>:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$input, "Value":$init,
+        "DenseI64ArrayAttr":$permutation, CArg<"ArrayRef<NamedAttribute>",
+        "{}">:$attributes)>,
+    OpBuilder<(ins "Value":$input, "Value":$init,
+        "ArrayRef<int64_t>":$permutation, CArg<"ArrayRef<NamedAttribute>",
+        "{}">:$attributes)>,
+  ];
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    // Declare functions necessary for LinalgStructuredInterface.
+    SmallVector<StringRef> getIteratorTypesArray();
+    ArrayAttr getIndexingMaps();
+    std::string getLibraryCallName() {
+      return "op_has_no_registered_library_name";
+    }
+
+    // Implement functions necessary for DestinationStyleOpInterface.
+    // The single `init` operand is the last operand.
+    std::pair<int64_t, int64_t> getOutputsPositionRange() {
+      int64_t numOperands = this->getNumOperands();
+      return {numOperands - 1, numOperands};
+    }
+
+    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+        mlir::ArrayRef<mlir::NamedAttribute>)>
+      getRegionBuilder();
+
+    static void createRegion(::mlir::OpBuilder &opBuilder,
+                             ::mlir::OperationState & odsState);
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
// Named Linalg ops, implemented as a declarative configurations of generic ops.
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 305b859ac13d1..6a10d4332e7eb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -41,10 +41,6 @@ bool hasOnlyScalarElementwiseOp(Region &r);
/// Check if a LinalgOp is an element-wise operation.
bool isElementwise(LinalgOp op);
-/// Check if `permutation` is a permutation of the range
-/// `[0, permutation.size())`.
-bool isPermutation(ArrayRef<int64_t> permutation);
/// Check if iterator type has "parallel" semantics.
bool isParallelIterator(StringRef iteratorType);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2fcd21cb59f99..82e5024cf58bf 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1601,6 +1601,142 @@ LogicalResult ReduceOp::verify() {
return success();
+// TransposeOp
+// TransposeOp
+
+/// Returns the callback that fills in the body of a transpose.
+///
+/// The body block has one argument per operand, in operand order: the input
+/// element first, then the init element. A transpose copies the input element
+/// to its permuted output position, so the body must yield the *input*
+/// argument. Yielding the last (init) argument would leave the output
+/// unchanged.
+std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+                   mlir::ArrayRef<mlir::NamedAttribute>)>
+TransposeOp::getRegionBuilder() {
+  return [](mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
+            mlir::ArrayRef<mlir::NamedAttribute>) {
+    b.create<linalg::YieldOp>(block.getArgument(0));
+  };
+}
+/// Adds the single-block region to `odsState` and populates it.
+///
+/// The block has one scalar argument per operand already recorded in
+/// `odsState.operands`: the element type of shaped operands, or the type
+/// itself otherwise. The block is then filled via getRegionBuilder().
+/// Must be called after the operands have been added.
+void TransposeOp::createRegion(::mlir::OpBuilder &opBuilder,
+                               ::mlir::OperationState &odsState) {
+  Region *region = odsState.addRegion();
+  // Attach the op's own location, not UnknownLoc, so that diagnostics
+  // reported on the body point back at the transpose.
+  Location loc = odsState.location;
+
+  SmallVector<Type> argTypes;
+  SmallVector<Location> argLocs;
+  argTypes.reserve(odsState.operands.size());
+  argLocs.reserve(odsState.operands.size());
+  for (Value operand : odsState.operands) {
+    argTypes.push_back(getElementTypeOrSelf(operand));
+    argLocs.push_back(loc);
+  }
+
+  // Restore the caller's insertion point once the body has been built.
+  OpBuilder::InsertionGuard guard(opBuilder);
+  Block *body =
+      opBuilder.createBlock(region, /*insertPt=*/{}, argTypes, argLocs);
+  ImplicitLocOpBuilder b(loc, opBuilder);
+  getRegionBuilder()(b, *body, odsState.attributes.getAttrs());
+}
+/// Builds a transpose of `input` into `init` using `permutation`.
+///
+/// A result of type `init.getType()` is added only when `init` is a ranked
+/// tensor. The op declares `Variadic<AnyTensor>` results, so for a memref
+/// `init` (which is written directly) adding its type as a result would build
+/// an op that violates its own result constraint.
+void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
+                        ::mlir::OperationState &odsState, Value input,
+                        Value init, DenseI64ArrayAttr permutation,
+                        ArrayRef<NamedAttribute> attributes) {
+  odsState.addOperands(input);
+  odsState.addOperands(init);
+  odsState.addAttribute(getPermutationAttrName(odsState.name), permutation);
+  odsState.addAttributes(attributes);
+
+  Type initType = init.getType();
+  if (initType.isa<RankedTensorType>())
+    odsState.addTypes(initType);
+
+  createRegion(odsBuilder, odsState);
+}
+/// Convenience overload taking the permutation as plain indices. It wraps them
+/// in a DenseI64ArrayAttr and forwards to the attribute-taking builder.
+void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
+                        ::mlir::OperationState &odsState, Value input,
+                        Value init, ArrayRef<int64_t> permutation,
+                        ArrayRef<NamedAttribute> attributes) {
+  DenseI64ArrayAttr permutationAttr =
+      odsBuilder.getDenseI64ArrayAttr(permutation);
+  build(odsBuilder, odsState, input, init, permutationAttr, attributes);
+}
+/// Parses the destination-style operand lists followed by the
+/// `permutation = [...]` clause. The body is not spelled out in the textual
+/// form, so it is rebuilt with createRegion() after parsing succeeds.
+ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
+  auto parsePermutation = [](OpAsmParser &p, NamedAttrList &attrs) {
+    return parseDenseI64ArrayAttr(p, attrs, "permutation");
+  };
+  if (failed(parseDstStyleOp(parser, result, parsePermutation)))
+    return failure();
+
+  OpBuilder opBuilder(parser.getContext());
+  createRegion(opBuilder, result);
+  return success();
+}
+/// Suggests the SSA name "transposed" for the op's result, if it has one.
+/// The memref form has no results, so there is nothing to name in that case.
+void TransposeOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  auto results = getResults();
+  if (results.empty())
+    return;
+  setNameFn(results.front(), "transposed");
+}
+/// Prints the input/output operand lists, then the `permutation` clause, then
+/// any remaining attributes.
+void TransposeOp::print(OpAsmPrinter &p) {
+  SmallVector<Value> inputs(getInputOperands());
+  SmallVector<Value> outputs(getOutputOperands());
+  printCommonStructuredOpParts(p, inputs, outputs);
+
+  // `permutation` is printed in its own clause, so it is elided from the
+  // trailing attribute dictionary to avoid printing it twice.
+  auto permutationName = getPermutationAttrName();
+  printDenseI64ArrayAttr(p, permutationName, getPermutation());
+  p.printOptionalAttrDict((*this)->getAttrs(), {permutationName});
+}
+/// Verifies four things:
+///   * `permutation` is a permutation of [0, rank);
+///   * input and init have equal rank;
+///   * the permutation length equals that rank;
+///   * dim(init, i) == dim(input, permutation[i]) for every i.
+LogicalResult TransposeOp::verify() {
+  ArrayRef<int64_t> perm = getPermutation();
+  if (!isPermutation(perm))
+    return emitOpError("permutation is not valid");
+
+  auto inputType = getInput().getType();
+  auto initType = getInit().getType();
+  int64_t inputRank = inputType.getRank();
+  int64_t initRank = initType.getRank();
+
+  // Both rank checks must pass before `perm` is used to index into the input
+  // shape below.
+  if (inputRank != initRank)
+    return emitOpError() << "input rank " << inputRank
+                         << " does not match init rank " << initRank;
+  if (static_cast<int64_t>(perm.size()) != inputRank)
+    return emitOpError() << "size of permutation " << perm.size()
+                         << " does not match the argument rank " << inputRank;
+
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  ArrayRef<int64_t> initShape = initType.getShape();
+  for (int64_t resultDim = 0; resultDim < inputRank; ++resultDim) {
+    int64_t expected = inputShape[perm[resultDim]];
+    int64_t actual = initShape[resultDim];
+    if (expected == actual)
+      continue;
+    return emitOpError() << "dim(result, " << resultDim << ") = " << actual
+                         << " doesn't match dim(input, permutation["
+                         << resultDim << "]) = " << expected;
+  }
+  return success();
+}
+/// A transpose performs no reduction, so one "parallel" iterator is reported
+/// for each dimension of `init`.
+SmallVector<StringRef> TransposeOp::getIteratorTypesArray() {
+  const int64_t numLoops = getInit().getType().getRank();
+  SmallVector<StringRef> iteratorTypes;
+  iteratorTypes.assign(numLoops, getParallelIteratorTypeName());
+  return iteratorTypes;
+}
+/// Returns the [input, init] indexing maps.
+///
+/// The loops iterate over the input's dimensions, so the input is accessed
+/// through the identity map. The init is accessed through the permutation
+/// map, so result dimension i is indexed by loop `permutation[i]`.
+ArrayAttr TransposeOp::getIndexingMaps() {
+  MLIRContext *ctx = getContext();
+  Builder builder(ctx);
+  const int64_t rank = getInit().getType().getRank();
+
+  AffineMap inputMap = builder.getMultiDimIdentityMap(rank);
+  AffineMap initMap = AffineMap::getPermutationMap(
+      llvm::to_vector_of<unsigned>(getPermutation()), ctx);
+  return builder.getAffineMapArrayAttr({inputMap, initMap});
+}
+/// Reports memory effects by delegating to getGenericEffectsImpl. It is given
+/// this op's results together with its input and output (init) operands.
+void TransposeOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  Operation *op = getOperation();
+  getGenericEffectsImpl(effects, op->getResults(), getInputOperands(),
+                        getOutputOperands());
+}
// YieldOp
@@ -1710,6 +1846,19 @@ SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
return llvm::to_vector<4>(concatRanges);
+/// Returns true iff every index in [0, permutation.size()) occurs exactly once
+/// in `permutation`. An empty array is accepted.
+bool mlir::linalg::isPermutation(ArrayRef<int64_t> permutation) {
+  const int64_t size = static_cast<int64_t>(permutation.size());
+  // Tracks which positions of [0, size) have already been claimed.
+  SmallVector<bool> seen(size, false);
+  for (int64_t index : permutation) {
+    // Out-of-range entries can never be part of a permutation of [0, size).
+    if (index < 0 || index >= size)
+      return false;
+    // A repeated entry means some other position must be missing.
+    if (seen[index])
+      return false;
+    seen[index] = true;
+  }
+  // `size` in-range, pairwise-distinct entries cover [0, size) exactly.
+  return true;
+}
static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
if (auto memref = t.dyn_cast<MemRefType>()) {
ss << "view";
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index aba2d5f5cd49f..af5a2012429ba 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -186,19 +186,6 @@ bool isElementwise(LinalgOp op) {
return hasOnlyScalarElementwiseOp(op->getRegion(0));
-bool isPermutation(ArrayRef<int64_t> permutation) {
- // Count the number of appearances for all indices.
- SmallVector<int64_t> indexCounts(permutation.size(), 0);
- for (auto index : permutation) {
- // Exit if the index is out-of-range.
- if (index < 0 || index >= static_cast<int64_t>(permutation.size()))
- return false;
- indexCounts[index]++;
- }
- // Return true if all indices appear once.
- return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
bool isParallelIterator(StringRef iteratorType) {
return iteratorType == getParallelIteratorTypeName();
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 00352c43bb07e..e6ab837141f1f 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -624,3 +624,52 @@ func.func @reduce_different_output_shapes(%input1: tensor<16x32x64xf32>,
func.return %reduce, %reduce2 : tensor<16x64xf32>, tensor<17x64xf32>
+// -----
+// [1, 1, 2] repeats index 1 and omits index 0, so it is rejected by the
+// permutation check, which runs before any rank or shape checks.
+func.func @transpose_invalid_permutation(%input: tensor<16x32x64xf32>,
+ %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
+ // expected-error @+1 {{'linalg.transpose' op permutation is not valid}}
+ %transpose = linalg.transpose
+ ins(%input:tensor<16x32x64xf32>)
+ outs(%init:tensor<32x64x16xf32>)
+ permutation = [1, 1, 2]
+ func.return %transpose : tensor<32x64x16xf32>
+// -----
+// The identity permutation is valid and the ranks agree. However, it requires
+// dim(result, 0) == dim(input, 0), which is 32 vs 16 here, so the first
+// per-dimension comparison fails.
+func.func @transpose_permutated_dims_mismatch(%input: tensor<16x32x64xf32>,
+ %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
+ // expected-error @+1 {{'linalg.transpose' op dim(result, 0) = 32 doesn't match dim(input, permutation[0]) = 16}}
+ %transpose = linalg.transpose
+ ins(%input:tensor<16x32x64xf32>)
+ outs(%init:tensor<32x64x16xf32>)
+ permutation = [0, 1, 2]
+ func.return %transpose : tensor<32x64x16xf32>
+// -----
+// [1, 0] is itself a valid permutation and both operands have rank 3. The
+// error therefore comes from the permutation length (2) not matching that rank.
+func.func @transpose_rank_permutation_size_mismatch(
+ %input: tensor<16x32x64xf32>,
+ %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
+ // expected-error @+1 {{'linalg.transpose' op size of permutation 2 does not match the argument rank 3}}
+ %transpose = linalg.transpose
+ ins(%input:tensor<16x32x64xf32>)
+ outs(%init:tensor<32x64x16xf32>)
+ permutation = [1, 0]
+ func.return %transpose : tensor<32x64x16xf32>
+// -----
+// [1, 0, 2] is a valid permutation, so verification reaches the rank
+// comparison between the rank-2 input and the rank-3 init.
+func.func @transpose_input_init_rank_mismatch(%input: tensor<16x32xf32>,
+ %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
+ // expected-error @+1 {{'linalg.transpose' op input rank 2 does not match init rank 3}}
+ %transpose = linalg.transpose
+ ins(%input:tensor<16x32xf32>)
+ outs(%init:tensor<32x64x16xf32>)
+ permutation = [1, 0, 2]
+ func.return %transpose : tensor<32x64x16xf32>
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index f751ddff7df0e..4bea3f6d38376 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -67,11 +67,11 @@ func.func @fill_view(%arg0: memref<?xf32, strided<[1], offset: ?>>, %arg1: f32)
// -----
-func.func @transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
+// Renamed from @transpose so that its label does not collide with the
+// linalg.transpose roundtrip test named @transpose further down this file.
+func.func @memref_transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
%0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
-// CHECK-LABEL: func @transpose
+// CHECK-LABEL: func @memref_transpose
// CHECK: memref.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
// CHECK-SAME: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
@@ -457,3 +457,27 @@ func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>,
// CHECK-LABEL: func @variadic_reduce_memref
// CHECK: linalg.reduce
+// -----
+// Tensor form: permutation [1, 2, 0] maps 16x32x64 to 32x64x16, and the op
+// produces one tensor result.
+// NOTE(review): only the function label is matched below; consider also
+// matching the printed linalg.transpose and its permutation.
+func.func @transpose(%input: tensor<16x32x64xf32>,
+ %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
+ %transpose = linalg.transpose
+ ins(%input:tensor<16x32x64xf32>)
+ outs(%init:tensor<32x64x16xf32>)
+ permutation = [1, 2, 0]
+ func.return %transpose : tensor<32x64x16xf32>
+// CHECK-LABEL: func @transpose
+// -----
+// Memref form: the op writes into %init and has no SSA result.
+func.func @transpose_memref(%input: memref<16x32x64xf32>,
+ %init: memref<32x64x16xf32>) {
+ linalg.transpose
+ ins(%input:memref<16x32x64xf32>)
+ outs(%init:memref<32x64x16xf32>)
+ permutation = [1, 2, 0]
+ func.return
+// CHECK-LABEL: func @transpose_memref
More information about the Mlir-commits
mailing list