[Mlir-commits] [mlir] [mlir][memref] Use array notation instead of permutation map for memref.transpose (PR #67880)

Felix Schneider llvmlistbot at llvm.org
Sat Sep 30 04:36:51 PDT 2023


https://github.com/ubfx created https://github.com/llvm/llvm-project/pull/67880

Until now, the dimensional permutation for memref.transpose was given in the form of an affine map. However, just from looking at such a representation, e.g. `(i, j) -> (j, i)`, it's not obvious whether it represents a mapping from the result dimensions to the source dimensions or the other way around. This has led to a bug (#65145).

This patch switches `memref.transpose` to the integer-array notation already used by ops such as `linalg.transpose` and `memref.collapse_shape`. That notation is harder to misinterpret and easier to work with.

>From 989cf7336902490262d48c87477cf423f450c903 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 30 Sep 2023 13:28:58 +0200
Subject: [PATCH] [mlir][memref] Use array notation instead of permutation map
 for memref.transpose

Until now, the dimensional permutation for memref.transpose was given in the form of an affine map. However, just from looking at such a representation, e.g. `(i, j) -> (j, i)`, it's not obvious whether it represents a mapping from the result dimensions to the source dimensions or the other way around. This has led to a bug (#65145).

This patch switches `memref.transpose` to the integer-array notation already used by ops such as `linalg.transpose` and `memref.collapse_shape`. That notation is harder to misinterpret and easier to work with.
---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       | 19 ++++-
 .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp  | 31 ++++---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 83 +++++++++++--------
 .../MemRefToLLVM/memref-to-llvm.mlir          |  2 +-
 mlir/test/Dialect/Linalg/roundtrip.mlir       |  4 +-
 mlir/test/Dialect/MemRef/invalid.mlir         |  6 +-
 6 files changed, 85 insertions(+), 60 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index ea6e363a6c3257f..30eb3feb097bf81 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -2119,7 +2119,7 @@ def TensorStoreOp : MemRef_Op<"tensor_store",
 def MemRef_TransposeOp : MemRef_Op<"transpose", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     Pure]>,
-    Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
+    Arguments<(ins AnyStridedMemRef:$in, DenseI64ArrayAttr:$permutation)>,
     Results<(outs AnyStridedMemRef)> {
   let summary = "`transpose` produces a new strided memref (metadata-only)";
   let description = [{
@@ -2127,22 +2127,33 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
     are a permutation of the original `in` memref. This is purely a metadata
     transformation.
 
+    The permutation is given in the form of an array of indices following the rule:
+    `dim(result, i) = dim(input, permutation[i])`
+
     Example:
 
     ```mlir
-    %1 = memref.transpose %0 (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+    %1 = memref.transpose %0 [1, 0] : memref<?x?xf32> to memref<?x?xf32, strided<[1, ?]>>
     ```
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$in, "AffineMapAttr":$permutation,
+    OpBuilder<(ins "Value":$in, "DenseI64ArrayAttr":$permutation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    OpBuilder<(ins "Value":$in, "ArrayRef<int64_t>":$permutation,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
 
   let extraClassDeclaration = [{
     static StringRef getPermutationAttrStrName() { return "permutation"; }
+
+    /// Returns true if the permutation is the identity (permutation[i] == i).
+    bool isIdentity();
+  }];
+
+  let assemblyFormat = [{
+    $in $permutation attr-dict `:` type($in) `to` type(results)
   }];
 
-  let hasCustomAssemblyFormat = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 61bd23f12601c79..2e34b690ae7a55f 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1410,38 +1410,35 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
     MemRefDescriptor viewMemRef(adaptor.getIn());
 
     // No permutation, early exit.
-    if (transposeOp.getPermutation().isIdentity())
+    if (transposeOp.isIdentity())
       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
 
-    auto targetMemRef = MemRefDescriptor::undef(
+    auto resultMemRef = MemRefDescriptor::undef(
         rewriter, loc,
         typeConverter->convertType(transposeOp.getIn().getType()));
 
     // Copy the base and aligned pointers from the old descriptor to the new
     // one.
-    targetMemRef.setAllocatedPtr(rewriter, loc,
+    resultMemRef.setAllocatedPtr(rewriter, loc,
                                  viewMemRef.allocatedPtr(rewriter, loc));
-    targetMemRef.setAlignedPtr(rewriter, loc,
+    resultMemRef.setAlignedPtr(rewriter, loc,
                                viewMemRef.alignedPtr(rewriter, loc));
 
     // Copy the offset pointer from the old descriptor to the new one.
-    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
+    resultMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
 
-    // Iterate over the dimensions and apply size/stride permutation:
-    // When enumerating the results of the permutation map, the enumeration index
-    // is the index into the target dimensions and the DimExpr points to the
-    // dimension of the source memref.
-    for (const auto &en :
-         llvm::enumerate(transposeOp.getPermutation().getResults())) {
-      int targetPos = en.index();
-      int sourcePos = en.value().cast<AffineDimExpr>().getPosition();
-      targetMemRef.setSize(rewriter, loc, targetPos,
-                           viewMemRef.size(rewriter, loc, sourcePos));
-      targetMemRef.setStride(rewriter, loc, targetPos,
-                             viewMemRef.stride(rewriter, loc, sourcePos));
+    // Result dim i takes its size and stride from input dim permutation[i].
+    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+    for (int64_t resultDimPos = 0, rank = permutation.size();
+         resultDimPos < rank; ++resultDimPos) {
+      int64_t originalDimPos = permutation[resultDimPos];
+      resultMemRef.setSize(rewriter, loc, resultDimPos,
+                           viewMemRef.size(rewriter, loc, originalDimPos));
+      resultMemRef.setStride(rewriter, loc, resultDimPos,
+                             viewMemRef.stride(rewriter, loc, originalDimPos));
     }
 
-    rewriter.replaceOp(transposeOp, {targetMemRef});
+    rewriter.replaceOp(transposeOp, {resultMemRef});
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 215a8f5e7d18be0..fa28c850aea960a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3176,21 +3176,22 @@ void TransposeOp::getAsmResultNames(
   setNameFn(getResult(), "transpose");
 }
 
-/// Build a strided memref type by applying `permutationMap` tp `memRefType`.
+/// Build a strided memref type by applying `permutation` to `memRefType`.
 static MemRefType inferTransposeResultType(MemRefType memRefType,
-                                           AffineMap permutationMap) {
+                                           ArrayRef<int64_t> permutation) {
   auto rank = memRefType.getRank();
   auto originalSizes = memRefType.getShape();
   auto [originalStrides, offset] = getStridesAndOffset(memRefType);
   assert(originalStrides.size() == static_cast<unsigned>(rank));
+  assert(static_cast<int64_t>(permutation.size()) == rank);
 
   // Compute permuted sizes and strides.
   SmallVector<int64_t> sizes(rank, 0);
   SmallVector<int64_t> strides(rank, 1);
-  for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
-    unsigned position = en.value().cast<AffineDimExpr>().getPosition();
-    sizes[en.index()] = originalSizes[position];
-    strides[en.index()] = originalStrides[position];
+  for (int64_t resultDimPos = 0; resultDimPos < rank; ++resultDimPos) {
+    int64_t originalDimPos = permutation[resultDimPos];
+    sizes[resultDimPos] = originalSizes[originalDimPos];
+    strides[resultDimPos] = originalStrides[originalDimPos];
   }
 
   return MemRefType::Builder(memRefType)
@@ -3200,52 +3201,59 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
 }
 
 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
-                        AffineMapAttr permutation,
+                        DenseI64ArrayAttr permutation,
                         ArrayRef<NamedAttribute> attrs) {
-  auto permutationMap = permutation.getValue();
-  assert(permutationMap);
-
   auto memRefType = llvm::cast<MemRefType>(in.getType());
   // Compute result type.
-  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
+  MemRefType resultType =
+      inferTransposeResultType(memRefType, permutation.asArrayRef());
 
   build(b, result, resultType, in, attrs);
   result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
 }
 
-// transpose $in $permutation attr-dict : type($in) `to` type(results)
-void TransposeOp::print(OpAsmPrinter &p) {
-  p << " " << getIn() << " " << getPermutation();
-  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
-  p << " : " << getIn().getType() << " to " << getType();
+void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
+                        ArrayRef<int64_t> permutation,
+                        ArrayRef<NamedAttribute> attrs) {
+  auto memRefType = llvm::cast<MemRefType>(in.getType());
+  // Compute result type.
+  MemRefType resultType = inferTransposeResultType(memRefType, permutation);
+
+  build(b, result, resultType, in, attrs);
+  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
+                      b.getDenseI64ArrayAttr(permutation));
 }
 
-ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::UnresolvedOperand in;
-  AffineMap permutation;
-  MemRefType srcType, dstType;
-  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(srcType) ||
-      parser.resolveOperand(in, srcType, result.operands) ||
-      parser.parseKeywordType("to", dstType) ||
-      parser.addTypeToList(dstType, result.types))
-    return failure();
-
-  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
-                      AffineMapAttr::get(permutation));
-  return success();
+/// Check whether the supplied array is a permutation index array, i.e. it
+/// contains each of the elements 0..size()-1 exactly once.
+///
+/// Runs in linear time: every entry must lie in [0, size()) and occur at
+/// most once; size() such entries necessarily cover every index.
+static bool isPermutationArray(ArrayRef<int64_t> arr) {
+  int64_t size = arr.size();
+  // seen[v] records whether index v has already occurred in `arr`.
+  SmallVector<bool> seen(size, false);
+  for (int64_t value : arr) {
+    // Reject out-of-range entries and duplicates.
+    if (value < 0 || value >= size || seen[value])
+      return false;
+    seen[value] = true;
+  }
+
+  return true;
 }
 
 LogicalResult TransposeOp::verify() {
-  if (!getPermutation().isPermutation())
+  ArrayRef<int64_t> permutation = getPermutation();
+
+  if (!isPermutationArray(permutation))
     return emitOpError("expected a permutation map");
-  if (getPermutation().getNumDims() != getIn().getType().getRank())
+  if (static_cast<int64_t>(permutation.size()) != getIn().getType().getRank())
     return emitOpError("expected a permutation map of same rank as the input");
 
   auto srcType = llvm::cast<MemRefType>(getIn().getType());
   auto dstType = llvm::cast<MemRefType>(getType());
-  auto transposedType = inferTransposeResultType(srcType, getPermutation());
+  auto transposedType = inferTransposeResultType(srcType, permutation);
   if (dstType != transposedType)
     return emitOpError("output type ")
            << dstType << " does not match transposed input type " << srcType
@@ -3259,6 +3267,15 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
   return {};
 }
 
+bool TransposeOp::isIdentity() {
+  // Identity iff result dimension i is taken from input dimension i.
+  ArrayRef<int64_t> perm = getPermutation();
+  for (size_t pos = 0, e = perm.size(); pos < e; ++pos)
+    if (perm[pos] != static_cast<int64_t>(pos))
+      return false;
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // ViewOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 9e44029ad93bd9c..355c9d494208212 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -243,7 +243,7 @@ func.func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
 //       CHECK:   llvm.extractvalue {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 //       CHECK:    llvm.insertvalue {{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 func.func @transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
-  %0 = memref.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
+  %0 = memref.transpose %arg0 [2, 0, 1] : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>  // dim(result, i) = dim(input, [2, 0, 1][i])
   return
 }
 
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 6203cf1c76d144c..b909b46095f053e 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -68,11 +68,11 @@ func.func @fill_view(%arg0: memref<?xf32, strided<[1], offset: ?>>, %arg1: f32)
 // -----
 
 func.func @memref_transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
-  %0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
+  %0 = memref.transpose %arg0 [2, 0, 1] : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
   return
 }
 // CHECK-LABEL: func @memref_transpose
-//       CHECK:   memref.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
+//       CHECK:   memref.transpose %{{.*}} [2, 0, 1] :
 //  CHECK-SAME:      memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
 
 // -----
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index cb5977e302a993f..e4037b4bf0cdc2a 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -129,21 +129,28 @@ func.func @dma_wait_wrong_index_type(%tag : memref<2x2xi32>, %idx: index, %flt:
 
 func.func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
   // expected-error @+1 {{expected a permutation map}}
-  memref.transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+  memref.transpose %v [1, 1] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
 
 // -----
 
 func.func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
   // expected-error @+1 {{expected a permutation map of same rank as the input}}
-  memref.transpose %v (i) -> (i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+  memref.transpose %v [0] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
 
 // -----
 
 func.func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
   // expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
-  memref.transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+  memref.transpose %v [1, 0] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
 
+// -----
+
+func.func @transpose_index_out_of_range(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
+  // expected-error @+1 {{expected a permutation map}}
+  memref.transpose %v [0, 2] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+}
+
 // -----



More information about the Mlir-commits mailing list