[Mlir-commits] [mlir] [mlir][memref] Transpose: allow affine map layouts in result, extend folder (PR #76294)
Felix Schneider
llvmlistbot at llvm.org
Wed Jan 10 11:21:11 PST 2024
https://github.com/ubfx updated https://github.com/llvm/llvm-project/pull/76294
>From 7a5c1a5c0d63917ba14a5b6dbba7080373939f26 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 23 Dec 2023 18:02:40 +0100
Subject: [PATCH 1/4] [mlir][memref] Transpose: allow affine map layouts in
result, extend folder
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 48 ++++++++++++++--------
mlir/test/Dialect/MemRef/canonicalize.mlir | 23 +++++++++++
mlir/test/Dialect/MemRef/invalid.mlir | 2 +-
mlir/test/Dialect/MemRef/ops.mlir | 6 +++
4 files changed, 60 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a332fe253ba645..8d7cb6e1cc92cc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3148,7 +3148,7 @@ void TransposeOp::getAsmResultNames(
setNameFn(getResult(), "transpose");
}
-/// Build a strided memref type by applying `permutationMap` tp `memRefType`.
+/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
AffineMap permutationMap) {
auto rank = memRefType.getRank();
@@ -3157,18 +3157,14 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
assert(originalStrides.size() == static_cast<unsigned>(rank));
// Compute permuted sizes and strides.
- SmallVector<int64_t> sizes(rank, 0);
- SmallVector<int64_t> strides(rank, 1);
- for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
- unsigned position = cast<AffineDimExpr>(en.value()).getPosition();
- sizes[en.index()] = originalSizes[position];
- strides[en.index()] = originalStrides[position];
- }
+ auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
+ auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
- return MemRefType::Builder(memRefType)
- .setShape(sizes)
- .setLayout(
- StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
+ auto stridedTy = MemRefType::Builder(memRefType)
+ .setShape(sizes)
+ .setLayout(StridedLayoutAttr::get(
+ memRefType.getContext(), offset, strides));
+ return canonicalizeStridedLayout(stridedTy);
}
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
@@ -3216,18 +3212,34 @@ LogicalResult TransposeOp::verify() {
return emitOpError("expected a permutation map of same rank as the input");
auto srcType = llvm::cast<MemRefType>(getIn().getType());
- auto dstType = llvm::cast<MemRefType>(getType());
- auto transposedType = inferTransposeResultType(srcType, getPermutation());
- if (dstType != transposedType)
- return emitOpError("output type ")
- << dstType << " does not match transposed input type " << srcType
- << ", " << transposedType;
+ auto canonicalDstType =
+ canonicalizeStridedLayout(llvm::cast<MemRefType>(getType()));
+ auto inferedDstType = inferTransposeResultType(srcType, getPermutation());
+
+ if (canonicalDstType != inferedDstType)
+ return emitOpError("canonicalized output type ")
+ << canonicalDstType
+ << " does not match canonical transposed input type " << srcType
+ << ", " << inferedDstType;
return success();
}
OpFoldResult TransposeOp::fold(FoldAdaptor) {
+ // First check for identity permutation, we can fold it away if input and
+ // result types are identical already.
+ if (getPermutation().isIdentity() && getType() == getIn().getType())
+ return getIn();
if (succeeded(foldMemRefCast(*this)))
return getResult();
+ // Fold two consecutive memref.transpose Ops into one by composing their
+ // permutation maps.
+ if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
+ AffineMap composedPermutation =
+ otherTransposeOp.getPermutation().compose(getPermutation());
+ getInMutable().assign(otherTransposeOp.getIn());
+ setPermutation(composedPermutation);
+ return getResult();
+ }
return {};
}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index d3406c630f6dd7..3471a1f912e7ea 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -988,3 +988,26 @@ func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
// CHECK: return %[[cast]]
return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
}
+
+// -----
+
+// CHECK-LABEL: func @fold_double_transpose(
+// CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32>
+func.func @fold_double_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> {
+ // CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0)
+ %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
+ %1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
+ // CHECK: return %[[ONETRANSPOSE]]
+ return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_identity_transpose(
+// CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32>
+func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3x4x5xf32> {
+ %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
+ %1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d0, d1, d2, d3, d4) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<1x2x3x4x5xf32>
+ // CHECK: return %[[arg0]]
+ return %1 : memref<1x2x3x4x5xf32>
+}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index f9b870f77266e1..25e08eda8f4dac 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -142,7 +142,7 @@ func.func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(o
// -----
func.func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
- // expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
+ // expected-error @+1 {{canonicalized output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match canonical transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>', 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 + s0 + d1 * s1)>>'}}
memref.transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 7e2018ca58dc4a..a7730b71a0eacf 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -378,3 +378,9 @@ func.func @memref_memory_space_cast(%src : memref<?xf32>) -> memref<?xf32, 1> {
%dst = memref.memory_space_cast %src : memref<?xf32> to memref<?xf32, 1>
return %dst : memref<?xf32, 1>
}
+
+// CHECK-LABEL: func @memref_transpose_map
+func.func @memref_transpose_map(%src : memref<?x?xf32>) -> memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>> {
+ %dst = memref.transpose %src (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+ return %dst : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+}
\ No newline at end of file
>From 80d87af6f0450c50f5dffd5968295c4053a742c0 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 23 Dec 2023 18:18:00 +0100
Subject: [PATCH 2/4] Add missing newline at end of ops.mlir
---
mlir/test/Dialect/MemRef/ops.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index a7730b71a0eacf..2d69904f27db5e 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -383,4 +383,4 @@ func.func @memref_memory_space_cast(%src : memref<?xf32>) -> memref<?xf32, 1> {
func.func @memref_transpose_map(%src : memref<?x?xf32>) -> memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>> {
%dst = memref.transpose %src (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
return %dst : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
-}
\ No newline at end of file
+}
>From 0182df5a8c8c41374ffceba16da83db594156a1d Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 23 Dec 2023 19:38:25 +0100
Subject: [PATCH 3/4] Compose the permutation maps in the correct order and add
a test that would have caught the issue
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 2 +-
mlir/test/Dialect/MemRef/canonicalize.mlir | 12 ++++++++++++
2 files changed, 13 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 8d7cb6e1cc92cc..8b9fd34607219a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3235,7 +3235,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
// permutation maps.
if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
AffineMap composedPermutation =
- otherTransposeOp.getPermutation().compose(getPermutation());
+ getPermutation().compose(otherTransposeOp.getPermutation());
getInMutable().assign(otherTransposeOp.getIn());
setPermutation(composedPermutation);
return getResult();
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 3471a1f912e7ea..eccfc485b2034e 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1003,6 +1003,18 @@ func.func @fold_double_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4
// -----
+// CHECK-LABEL: func @fold_double_transpose2(
+// CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32>
+func.func @fold_double_transpose2(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> {
+ // CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0)
+ %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d0, d1, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<1x2x5x4x3xf32, strided<[120, 60, 1, 5, 20]>>
+ %1 = memref.transpose %0 (d0, d1, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<1x2x5x4x3xf32, strided<[120, 60, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
+ // CHECK: return %[[ONETRANSPOSE]]
+ return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_identity_transpose(
// CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32>
func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3x4x5xf32> {
>From 5067ba1c9797ae2704d2f1675f5fadeaf2cc1fa7 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Wed, 10 Jan 2024 20:20:18 +0100
Subject: [PATCH 4/4] Infer a non-canonicalized strided layout for the transpose
result; canonicalize only when verifying
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 27 ++++++++++++------------
mlir/test/Dialect/MemRef/invalid.mlir | 2 +-
2 files changed, 14 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 8b9fd34607219a..b6c53838e615fd 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3160,11 +3160,10 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
- auto stridedTy = MemRefType::Builder(memRefType)
- .setShape(sizes)
- .setLayout(StridedLayoutAttr::get(
- memRefType.getContext(), offset, strides));
- return canonicalizeStridedLayout(stridedTy);
+ return MemRefType::Builder(memRefType)
+ .setShape(sizes)
+ .setLayout(
+ StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
}
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
@@ -3212,15 +3211,15 @@ LogicalResult TransposeOp::verify() {
return emitOpError("expected a permutation map of same rank as the input");
auto srcType = llvm::cast<MemRefType>(getIn().getType());
- auto canonicalDstType =
- canonicalizeStridedLayout(llvm::cast<MemRefType>(getType()));
- auto inferedDstType = inferTransposeResultType(srcType, getPermutation());
-
- if (canonicalDstType != inferedDstType)
- return emitOpError("canonicalized output type ")
- << canonicalDstType
- << " does not match canonical transposed input type " << srcType
- << ", " << inferedDstType;
+ auto resultType = llvm::cast<MemRefType>(getType());
+ auto canonicalResultType = canonicalizeStridedLayout(
+ inferTransposeResultType(srcType, getPermutation()));
+
+ if (canonicalizeStridedLayout(resultType) != canonicalResultType)
+ return emitOpError("result type ")
+ << resultType
+ << " is not equivalent to the canonical transposed input type "
+ << canonicalResultType;
return success();
}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 25e08eda8f4dac..7bb7a2affcbd19 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -142,7 +142,7 @@ func.func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(o
// -----
func.func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
- // expected-error @+1 {{canonicalized output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match canonical transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>', 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 + s0 + d1 * s1)>>'}}
+ // expected-error @+1 {{result type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' is not equivalent to the canonical transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 + s0 + d1 * s1)>>'}}
memref.transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}
More information about the Mlir-commits
mailing list