[Mlir-commits] [mlir] [mlir][memref] Transpose: allow affine map layouts in result, extend folder (PR #76294)
Felix Schneider
llvmlistbot at llvm.org
Sat Dec 23 09:17:53 PST 2023
https://github.com/ubfx created https://github.com/llvm/llvm-project/pull/76294
Currently, the `memref.transpose` verifier forces the result type of the op to have an explicit `StridedLayoutAttr` via the method `inferTransposeResultType`. This means that, for example, the op given in the documentation (https://mlir.llvm.org/docs/Dialects/MemRef/#memreftranspose-memreftransposeop) is actually invalid, because it uses an `AffineMap` to specify the layout:
```mlir
%1 = memref.transpose %0 (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
```
It also means that we can't "un-transpose" a transposed memref back to the implicit layout form, because the verifier will always enforce the explicit strided layout.
This patch makes the following changes:
1. `inferTransposeResultType()` returns a `MemRefType` with a canonicalized strided layout, i.e., the strides are turned into a linearizing affine expression.
2. The verifier checks whether the canonicalized strided layout of the result type is identical to the inferred (also canonical) result type layout. This way, the two types only need to describe the same strided layout, not necessarily with the same representation.
3. The folder is extended to fold away the trivial case of an identity permutation, and to fold two consecutive transpositions into one by composing their permutation maps.
>From d11685038d02ef680427092a8d2f2c821312fa20 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 23 Dec 2023 18:02:40 +0100
Subject: [PATCH] [mlir][memref] Transpose: allow affine map layouts in result,
extend folder
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 48 ++++++++++++++--------
mlir/test/Dialect/MemRef/canonicalize.mlir | 23 +++++++++++
mlir/test/Dialect/MemRef/invalid.mlir | 2 +-
mlir/test/Dialect/MemRef/ops.mlir | 6 +++
4 files changed, 60 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a332fe253ba645..8d7cb6e1cc92cc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3148,7 +3148,7 @@ void TransposeOp::getAsmResultNames(
setNameFn(getResult(), "transpose");
}
-/// Build a strided memref type by applying `permutationMap` tp `memRefType`.
+/// Apply `permutationMap` to `memRefType`; the layout is returned canonicalized.
static MemRefType inferTransposeResultType(MemRefType memRefType,
AffineMap permutationMap) {
auto rank = memRefType.getRank();
@@ -3157,18 +3157,14 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
assert(originalStrides.size() == static_cast<unsigned>(rank));
// Compute permuted sizes and strides.
- SmallVector<int64_t> sizes(rank, 0);
- SmallVector<int64_t> strides(rank, 1);
- for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
- unsigned position = cast<AffineDimExpr>(en.value()).getPosition();
- sizes[en.index()] = originalSizes[position];
- strides[en.index()] = originalStrides[position];
- }
+ auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
+ auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
- return MemRefType::Builder(memRefType)
- .setShape(sizes)
- .setLayout(
- StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
+ auto stridedTy = MemRefType::Builder(memRefType)
+ .setShape(sizes)
+ .setLayout(StridedLayoutAttr::get(
+ memRefType.getContext(), offset, strides));
+ return canonicalizeStridedLayout(stridedTy);
}
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
@@ -3216,18 +3212,34 @@ LogicalResult TransposeOp::verify() {
return emitOpError("expected a permutation map of same rank as the input");
auto srcType = llvm::cast<MemRefType>(getIn().getType());
- auto dstType = llvm::cast<MemRefType>(getType());
- auto transposedType = inferTransposeResultType(srcType, getPermutation());
- if (dstType != transposedType)
- return emitOpError("output type ")
- << dstType << " does not match transposed input type " << srcType
- << ", " << transposedType;
+ auto canonicalDstType =
+ canonicalizeStridedLayout(llvm::cast<MemRefType>(getType()));
+ auto canonicalInferredType = canonicalizeStridedLayout(
+ inferTransposeResultType(srcType, getPermutation()));
+ if (canonicalDstType != canonicalInferredType)
+ return emitOpError("canonicalized output type ")
+ << canonicalDstType
+ << " does not match canonical transposed input type " << srcType
+ << ", " << canonicalInferredType;
return success();
}
OpFoldResult TransposeOp::fold(FoldAdaptor) {
+ // First check for identity permutation, we can fold it away if input and
+ // result types are identical already.
+ if (getPermutation().isIdentity() && getType() == getIn().getType())
+ return getIn();
if (succeeded(foldMemRefCast(*this)))
return getResult();
+ // Fold transpose(transpose(x, p), q) into transpose(x, q.compose(p)): dim i
+ // of the result reads dim p[q[i]] of x, which is what q.compose(p) yields.
+ if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
+ AffineMap composedPermutation =
+ getPermutation().compose(otherTransposeOp.getPermutation());
+ getInMutable().assign(otherTransposeOp.getIn());
+ setPermutation(composedPermutation);
+ return getResult();
+ }
return {};
}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index d3406c630f6dd7..3471a1f912e7ea 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -988,3 +988,26 @@ func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
// CHECK: return %[[cast]]
return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
}
+
+// -----
+
+// CHECK-LABEL: func @fold_double_transpose(
+// CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32>
+func.func @fold_double_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> {
+ // CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0)
+ %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
+ %1 = memref.transpose %0 (d0, d1, d2, d3, d4) -> (d2, d4, d0, d3, d1) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
+ // CHECK: return %[[ONETRANSPOSE]]
+ return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_identity_transpose(
+// CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32>
+func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3x4x5xf32> {
+ %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
+ %1 = memref.transpose %0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<1x2x3x4x5xf32>
+ // CHECK: return %[[arg0]]
+ return %1 : memref<1x2x3x4x5xf32>
+}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index f9b870f77266e1..25e08eda8f4dac 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -142,7 +142,7 @@ func.func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(o
// -----
func.func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
- // expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
+ // expected-error @+1 {{canonicalized output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match canonical transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>', 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 + s0 + d1 * s1)>>'}}
memref.transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 7e2018ca58dc4a..a7730b71a0eacf 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -378,3 +378,9 @@ func.func @memref_memory_space_cast(%src : memref<?xf32>) -> memref<?xf32, 1> {
%dst = memref.memory_space_cast %src : memref<?xf32> to memref<?xf32, 1>
return %dst : memref<?xf32, 1>
}
+
+// CHECK-LABEL: func @memref_transpose_map
+func.func @memref_transpose_map(%src : memref<?x?xf32>) -> memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>> {
+ %dst = memref.transpose %src (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+ return %dst : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+}
More information about the Mlir-commits
mailing list