[Mlir-commits] [mlir] [mlir][memref] Transpose: allow affine map layouts in result, extend folder (PR #76294)

Felix Schneider llvmlistbot at llvm.org
Wed Jan 10 11:21:11 PST 2024


https://github.com/ubfx updated https://github.com/llvm/llvm-project/pull/76294

>From 7a5c1a5c0d63917ba14a5b6dbba7080373939f26 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 23 Dec 2023 18:02:40 +0100
Subject: [PATCH 1/4] [mlir][memref] Transpose: allow affine map layouts in
 result, extend folder

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   | 48 ++++++++++++++--------
 mlir/test/Dialect/MemRef/canonicalize.mlir | 23 +++++++++++
 mlir/test/Dialect/MemRef/invalid.mlir      |  2 +-
 mlir/test/Dialect/MemRef/ops.mlir          |  6 +++
 4 files changed, 60 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a332fe253ba645..8d7cb6e1cc92cc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3148,7 +3148,7 @@ void TransposeOp::getAsmResultNames(
   setNameFn(getResult(), "transpose");
 }
 
-/// Build a strided memref type by applying `permutationMap` tp `memRefType`.
+/// Build a strided memref type by applying `permutationMap` to `memRefType`.
 static MemRefType inferTransposeResultType(MemRefType memRefType,
                                            AffineMap permutationMap) {
   auto rank = memRefType.getRank();
@@ -3157,18 +3157,14 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
   assert(originalStrides.size() == static_cast<unsigned>(rank));
 
   // Compute permuted sizes and strides.
-  SmallVector<int64_t> sizes(rank, 0);
-  SmallVector<int64_t> strides(rank, 1);
-  for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
-    unsigned position = cast<AffineDimExpr>(en.value()).getPosition();
-    sizes[en.index()] = originalSizes[position];
-    strides[en.index()] = originalStrides[position];
-  }
+  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
+  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
 
-  return MemRefType::Builder(memRefType)
-      .setShape(sizes)
-      .setLayout(
-          StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
+  auto stridedTy = MemRefType::Builder(memRefType)
+                       .setShape(sizes)
+                       .setLayout(StridedLayoutAttr::get(
+                           memRefType.getContext(), offset, strides));
+  return canonicalizeStridedLayout(stridedTy);
 }
 
 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
@@ -3216,18 +3212,34 @@ LogicalResult TransposeOp::verify() {
     return emitOpError("expected a permutation map of same rank as the input");
 
   auto srcType = llvm::cast<MemRefType>(getIn().getType());
-  auto dstType = llvm::cast<MemRefType>(getType());
-  auto transposedType = inferTransposeResultType(srcType, getPermutation());
-  if (dstType != transposedType)
-    return emitOpError("output type ")
-           << dstType << " does not match transposed input type " << srcType
-           << ", " << transposedType;
+  auto canonicalDstType =
+      canonicalizeStridedLayout(llvm::cast<MemRefType>(getType()));
+  auto inferedDstType = inferTransposeResultType(srcType, getPermutation());
+
+  if (canonicalDstType != inferedDstType)
+    return emitOpError("canonicalized output type ")
+           << canonicalDstType
+           << " does not match canonical transposed input type " << srcType
+           << ", " << inferedDstType;
   return success();
 }
 
 OpFoldResult TransposeOp::fold(FoldAdaptor) {
+  // First check for identity permutation, we can fold it away if input and
+  // result types are identical already.
+  if (getPermutation().isIdentity() && getType() == getIn().getType())
+    return getIn();
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
+  // Fold two consecutive memref.transpose Ops into one by composing their
+  // permutation maps.
+  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
+    AffineMap composedPermutation =
+        otherTransposeOp.getPermutation().compose(getPermutation());
+    getInMutable().assign(otherTransposeOp.getIn());
+    setPermutation(composedPermutation);
+    return getResult();
+  }
   return {};
 }
 
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index d3406c630f6dd7..3471a1f912e7ea 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -988,3 +988,26 @@ func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
   // CHECK: return %[[cast]]
   return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_double_transpose(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<1x2x3x4x5xf32>
+func.func @fold_double_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> {
+  // CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0)
+  %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
+  %1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
+  // CHECK: return %[[ONETRANSPOSE]]
+  return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_identity_transpose(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<1x2x3x4x5xf32>
+func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3x4x5xf32> {
+  %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
+  %1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d0, d1, d2, d3, d4) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<1x2x3x4x5xf32>
+  // CHECK: return %[[arg0]]
+  return %1 : memref<1x2x3x4x5xf32>
+}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index f9b870f77266e1..25e08eda8f4dac 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -142,7 +142,7 @@ func.func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(o
 // -----
 
 func.func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
-  // expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
+  // expected-error @+1 {{canonicalized output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match canonical transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>', 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 + s0 + d1 * s1)>>'}}
   memref.transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
 
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 7e2018ca58dc4a..a7730b71a0eacf 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -378,3 +378,9 @@ func.func @memref_memory_space_cast(%src : memref<?xf32>) -> memref<?xf32, 1> {
   %dst = memref.memory_space_cast %src : memref<?xf32> to memref<?xf32, 1>
   return %dst : memref<?xf32, 1>
 }
+
+// CHECK-LABEL: func @memref_transpose_map
+func.func @memref_transpose_map(%src : memref<?x?xf32>) -> memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>> {
+  %dst = memref.transpose %src (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+  return %dst : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+}
\ No newline at end of file

>From 80d87af6f0450c50f5dffd5968295c4053a742c0 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 23 Dec 2023 18:18:00 +0100
Subject: [PATCH 2/4] Add missing trailing newline to ops.mlir test

---
 mlir/test/Dialect/MemRef/ops.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index a7730b71a0eacf..2d69904f27db5e 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -383,4 +383,4 @@ func.func @memref_memory_space_cast(%src : memref<?xf32>) -> memref<?xf32, 1> {
 func.func @memref_transpose_map(%src : memref<?x?xf32>) -> memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>> {
   %dst = memref.transpose %src (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
   return %dst : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
-}
\ No newline at end of file
+}

>From 0182df5a8c8c41374ffceba16da83db594156a1d Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 23 Dec 2023 19:38:25 +0100
Subject: [PATCH 3/4] Compose the permutation maps in the correct order and add
 a test that would have caught the issue

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   |  2 +-
 mlir/test/Dialect/MemRef/canonicalize.mlir | 12 ++++++++++++
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 8d7cb6e1cc92cc..8b9fd34607219a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3235,7 +3235,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
   // permutation maps.
   if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
     AffineMap composedPermutation =
-        otherTransposeOp.getPermutation().compose(getPermutation());
+        getPermutation().compose(otherTransposeOp.getPermutation());
     getInMutable().assign(otherTransposeOp.getIn());
     setPermutation(composedPermutation);
     return getResult();
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 3471a1f912e7ea..eccfc485b2034e 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1003,6 +1003,18 @@ func.func @fold_double_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4
 
 // -----
 
+// CHECK-LABEL: func @fold_double_transpose2(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<1x2x3x4x5xf32>
+func.func @fold_double_transpose2(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> {
+  // CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0)
+  %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d0, d1, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<1x2x5x4x3xf32, strided<[120, 60, 1, 5, 20]>>
+  %1 = memref.transpose %0 (d0, d1, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<1x2x5x4x3xf32, strided<[120, 60, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
+  // CHECK: return %[[ONETRANSPOSE]]
+  return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_identity_transpose(
 //  CHECK-SAME:     %[[arg0:.*]]: memref<1x2x3x4x5xf32>
 func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3x4x5xf32> {

>From 5067ba1c9797ae2704d2f1675f5fadeaf2cc1fa7 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Wed, 10 Jan 2024 20:20:18 +0100
Subject: [PATCH 4/4] Infer un-canonicalized strided layout; canonicalize only when verifying

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 27 ++++++++++++------------
 mlir/test/Dialect/MemRef/invalid.mlir    |  2 +-
 2 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 8b9fd34607219a..b6c53838e615fd 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3160,11 +3160,10 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
   auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
   auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
 
-  auto stridedTy = MemRefType::Builder(memRefType)
-                       .setShape(sizes)
-                       .setLayout(StridedLayoutAttr::get(
-                           memRefType.getContext(), offset, strides));
-  return canonicalizeStridedLayout(stridedTy);
+  return MemRefType::Builder(memRefType)
+      .setShape(sizes)
+      .setLayout(
+          StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
 }
 
 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
@@ -3212,15 +3211,15 @@ LogicalResult TransposeOp::verify() {
     return emitOpError("expected a permutation map of same rank as the input");
 
   auto srcType = llvm::cast<MemRefType>(getIn().getType());
-  auto canonicalDstType =
-      canonicalizeStridedLayout(llvm::cast<MemRefType>(getType()));
-  auto inferedDstType = inferTransposeResultType(srcType, getPermutation());
-
-  if (canonicalDstType != inferedDstType)
-    return emitOpError("canonicalized output type ")
-           << canonicalDstType
-           << " does not match canonical transposed input type " << srcType
-           << ", " << inferedDstType;
+  auto resultType = llvm::cast<MemRefType>(getType());
+  auto canonicalResultType = canonicalizeStridedLayout(
+      inferTransposeResultType(srcType, getPermutation()));
+
+  if (canonicalizeStridedLayout(resultType) != canonicalResultType)
+    return emitOpError("result type ")
+           << resultType
+           << " is not equivalent to the canonical transposed input type "
+           << canonicalResultType;
   return success();
 }
 
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 25e08eda8f4dac..7bb7a2affcbd19 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -142,7 +142,7 @@ func.func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(o
 // -----
 
 func.func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
-  // expected-error @+1 {{canonicalized output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match canonical transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>', 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 + s0 + d1 * s1)>>'}}
+  // expected-error @+1 {{result type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' is not equivalent to the canonical transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 + s0 + d1 * s1)>>'}}
   memref.transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
 



More information about the Mlir-commits mailing list