[Mlir-commits] [mlir] 4619e21 - [mlir][memref] Transpose: allow affine map layouts in result, extend folder (#76294)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 11 10:54:53 PST 2024


Author: Felix Schneider
Date: 2024-01-11T19:54:49+01:00
New Revision: 4619e21c72879591fbb5c8a1b1b5effe70b0a57e

URL: https://github.com/llvm/llvm-project/commit/4619e21c72879591fbb5c8a1b1b5effe70b0a57e
DIFF: https://github.com/llvm/llvm-project/commit/4619e21c72879591fbb5c8a1b1b5effe70b0a57e.diff

LOG: [mlir][memref] Transpose: allow affine map layouts in result, extend folder (#76294)

Currently, the `memref.transpose` verifier forces the result type of the
Op to have an explicit `StridedLayoutAttr` via the method
`inferTransposeResultType`. This means that the example Op
given in the documentation is actually invalid because it uses an `AffineMap`
to specify the layout.
It also means that we can't "un-transpose" a transposed memref back to
the implicit layout form, because the verifier will always enforce the
explicit strided layout.

This patch makes the following changes:

1. The verifier checks whether the canonicalized strided layout of the
result Type is identical to the canonicalized inferred result type
layout. This way, it's only important that the two Types have the same
strided layout, not necessarily the same representation of it.
2. The folder is extended to support folding away the trivial case of
identity permutation and to fold one transposition into another by
composing the permutation maps.

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/Dialect/MemRef/invalid.mlir
    mlir/test/Dialect/MemRef/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a332fe253ba645..b6c53838e615fd 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3148,7 +3148,7 @@ void TransposeOp::getAsmResultNames(
   setNameFn(getResult(), "transpose");
 }
 
-/// Build a strided memref type by applying `permutationMap` tp `memRefType`.
+/// Build a strided memref type by applying `permutationMap` to `memRefType`.
 static MemRefType inferTransposeResultType(MemRefType memRefType,
                                            AffineMap permutationMap) {
   auto rank = memRefType.getRank();
@@ -3157,13 +3157,8 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
   assert(originalStrides.size() == static_cast<unsigned>(rank));
 
   // Compute permuted sizes and strides.
-  SmallVector<int64_t> sizes(rank, 0);
-  SmallVector<int64_t> strides(rank, 1);
-  for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
-    unsigned position = cast<AffineDimExpr>(en.value()).getPosition();
-    sizes[en.index()] = originalSizes[position];
-    strides[en.index()] = originalStrides[position];
-  }
+  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
+  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
 
   return MemRefType::Builder(memRefType)
       .setShape(sizes)
@@ -3216,18 +3211,34 @@ LogicalResult TransposeOp::verify() {
     return emitOpError("expected a permutation map of same rank as the input");
 
   auto srcType = llvm::cast<MemRefType>(getIn().getType());
-  auto dstType = llvm::cast<MemRefType>(getType());
-  auto transposedType = inferTransposeResultType(srcType, getPermutation());
-  if (dstType != transposedType)
-    return emitOpError("output type ")
-           << dstType << " does not match transposed input type " << srcType
-           << ", " << transposedType;
+  auto resultType = llvm::cast<MemRefType>(getType());
+  auto canonicalResultType = canonicalizeStridedLayout(
+      inferTransposeResultType(srcType, getPermutation()));
+
+  if (canonicalizeStridedLayout(resultType) != canonicalResultType)
+    return emitOpError("result type ")
+           << resultType
+           << " is not equivalent to the canonical transposed input type "
+           << canonicalResultType;
   return success();
 }
 
 OpFoldResult TransposeOp::fold(FoldAdaptor) {
+  // First check for identity permutation, we can fold it away if input and
+  // result types are identical already.
+  if (getPermutation().isIdentity() && getType() == getIn().getType())
+    return getIn();
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
+  // Fold two consecutive memref.transpose Ops into one by composing their
+  // permutation maps.
+  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
+    AffineMap composedPermutation =
+        getPermutation().compose(otherTransposeOp.getPermutation());
+    getInMutable().assign(otherTransposeOp.getIn());
+    setPermutation(composedPermutation);
+    return getResult();
+  }
   return {};
 }
 

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index d3406c630f6dd7..eccfc485b2034e 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -988,3 +988,38 @@ func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
   // CHECK: return %[[cast]]
   return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_double_transpose(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<1x2x3x4x5xf32>
+func.func @fold_double_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> {
+  // CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0)
+  %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
+  %1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
+  // CHECK: return %[[ONETRANSPOSE]]
+  return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_double_transpose2(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<1x2x3x4x5xf32>
+func.func @fold_double_transpose2(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> {
+  // CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0)
+  %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d0, d1, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<1x2x5x4x3xf32, strided<[120, 60, 1, 5, 20]>>
+  %1 = memref.transpose %0 (d0, d1, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<1x2x5x4x3xf32, strided<[120, 60, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
+  // CHECK: return %[[ONETRANSPOSE]]
+  return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_identity_transpose(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<1x2x3x4x5xf32>
+func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3x4x5xf32> {
+  %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
+  %1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d0, d1, d2, d3, d4) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<1x2x3x4x5xf32>
+  // CHECK: return %[[arg0]]
+  return %1 : memref<1x2x3x4x5xf32>
+}

diff  --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index f9b870f77266e1..7bb7a2affcbd19 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -142,7 +142,7 @@ func.func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(o
 // -----
 
 func.func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
-  // expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
+  // expected-error @+1 {{result type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' is not equivalent to the canonical transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 + s0 + d1 * s1)>>'}}
   memref.transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
 

diff  --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 7e2018ca58dc4a..2d69904f27db5e 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -378,3 +378,9 @@ func.func @memref_memory_space_cast(%src : memref<?xf32>) -> memref<?xf32, 1> {
   %dst = memref.memory_space_cast %src : memref<?xf32> to memref<?xf32, 1>
   return %dst : memref<?xf32, 1>
 }
+
+// CHECK-LABEL: func @memref_transpose_map
+func.func @memref_transpose_map(%src : memref<?x?xf32>) -> memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>> {
+  %dst = memref.transpose %src (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+  return %dst : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+}


        


More information about the Mlir-commits mailing list