[Mlir-commits] [mlir] [mlir][memref] Use array notation instead of permutation map for memref.transpose (PR #67880)

Felix Schneider llvmlistbot at llvm.org
Sat Sep 30 06:45:42 PDT 2023


https://github.com/ubfx updated https://github.com/llvm/llvm-project/pull/67880

>From 64ca3c026abd19a78263dd9a70acccbdbaa5e397 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 30 Sep 2023 11:33:22 +0000
Subject: [PATCH] [mlir][memref] Use array notation instead of permutation map
 for memref.transpose

Until now, the dimension permutation for memref.transpose was given in the
form of an affine map. However, just from looking at such a representation,
e.g. `(i, j, k) -> (k, i, j)`, it is not obvious whether it maps the result
dimensions to the source dimensions or the other way around (a 2-D example
such as `(i, j) -> (j, i)` hides the question, because that permutation is its
own inverse). This ambiguity has led to a bug (#65145).

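This patch switches `memref.transpose` to the integer-array notation that is
also used by ops like `linalg.transpose`, `memref.collapse_shape` and others,
which is harder to misinterpret and easier to work with. The array follows the
rule `dim(result, i) = dim(input, permutation[i])`; for example, the map
`(i, j, k) -> (k, i, j)` is now written as `[2, 0, 1]`.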
---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       | 19 ++++-
 .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp  | 31 ++++---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 83 +++++++++++--------
 .../MemRefToLLVM/memref-to-llvm.mlir          |  2 +-
 mlir/test/Dialect/Linalg/roundtrip.mlir       |  4 +-
 mlir/test/Dialect/MemRef/invalid.mlir         |  6 +-
 6 files changed, 85 insertions(+), 60 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 6b0ccbe37e89e9c..067737d8a18fda9 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -2125,7 +2125,7 @@ def TensorStoreOp : MemRef_Op<"tensor_store",
 def MemRef_TransposeOp : MemRef_Op<"transpose", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     Pure]>,
-    Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
+    Arguments<(ins AnyStridedMemRef:$in, DenseI64ArrayAttr:$permutation)>,
     Results<(outs AnyStridedMemRef)> {
   let summary = "`transpose` produces a new strided memref (metadata-only)";
   let description = [{
@@ -2133,22 +2133,33 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
     are a permutation of the original `in` memref. This is purely a metadata
     transformation.
 
+    The permutation is given in the form of an array of indices following the rule:
+    `dim(result, i) = dim(input, permutation[i])`
+
     Example:
 
     ```mlir
-    %1 = memref.transpose %0 (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+    %1 = memref.transpose %0 [1, 0] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
     ```
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$in, "AffineMapAttr":$permutation,
+    OpBuilder<(ins "Value":$in, "DenseI64ArrayAttr":$permutation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    OpBuilder<(ins "Value":$in, "ArrayRef<int64_t>":$permutation,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
 
   let extraClassDeclaration = [{
     static StringRef getPermutationAttrStrName() { return "permutation"; }
+
+    /// Returns true if `permutation[i] == i` for every result dimension i.
+    bool isIdentity();
+  }];
+
+  let assemblyFormat = [{
+    $in $permutation attr-dict `:` type($in) `to` type(results)
   }];
 
-  let hasCustomAssemblyFormat = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 61bd23f12601c79..2e34b690ae7a55f 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1410,38 +1410,35 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
     MemRefDescriptor viewMemRef(adaptor.getIn());
 
     // No permutation, early exit.
-    if (transposeOp.getPermutation().isIdentity())
+    if (transposeOp.isIdentity())
       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
 
-    auto targetMemRef = MemRefDescriptor::undef(
+    auto resultMemRef = MemRefDescriptor::undef(
         rewriter, loc,
         typeConverter->convertType(transposeOp.getIn().getType()));
 
     // Copy the base and aligned pointers from the old descriptor to the new
     // one.
-    targetMemRef.setAllocatedPtr(rewriter, loc,
+    resultMemRef.setAllocatedPtr(rewriter, loc,
                                  viewMemRef.allocatedPtr(rewriter, loc));
-    targetMemRef.setAlignedPtr(rewriter, loc,
+    resultMemRef.setAlignedPtr(rewriter, loc,
                                viewMemRef.alignedPtr(rewriter, loc));
 
     // Copy the offset pointer from the old descriptor to the new one.
-    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
+    resultMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
 
     // Iterate over the dimensions and apply size/stride permutation:
-    // When enumerating the results of the permutation map, the enumeration index
-    // is the index into the target dimensions and the DimExpr points to the
-    // dimension of the source memref.
-    for (const auto &en :
-         llvm::enumerate(transposeOp.getPermutation().getResults())) {
-      int targetPos = en.index();
-      int sourcePos = en.value().cast<AffineDimExpr>().getPosition();
-      targetMemRef.setSize(rewriter, loc, targetPos,
-                           viewMemRef.size(rewriter, loc, sourcePos));
-      targetMemRef.setStride(rewriter, loc, targetPos,
-                             viewMemRef.stride(rewriter, loc, sourcePos));
+    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+    for (int64_t resultDimPos = 0, rank = permutation.size();
+         resultDimPos < rank; ++resultDimPos) {
+      int64_t originalDimPos = permutation[resultDimPos];
+      resultMemRef.setSize(rewriter, loc, resultDimPos,
+                           viewMemRef.size(rewriter, loc, originalDimPos));
+      resultMemRef.setStride(rewriter, loc, resultDimPos,
+                             viewMemRef.stride(rewriter, loc, originalDimPos));
     }
 
-    rewriter.replaceOp(transposeOp, {targetMemRef});
+    rewriter.replaceOp(transposeOp, {resultMemRef});
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 215a8f5e7d18be0..fa28c850aea960a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3176,21 +3176,22 @@ void TransposeOp::getAsmResultNames(
   setNameFn(getResult(), "transpose");
 }
 
-/// Build a strided memref type by applying `permutationMap` tp `memRefType`.
+/// Build a strided memref type by applying `permutation` to `memRefType`.
 static MemRefType inferTransposeResultType(MemRefType memRefType,
-                                           AffineMap permutationMap) {
+                                           ArrayRef<int64_t> permutation) {
   auto rank = memRefType.getRank();
   auto originalSizes = memRefType.getShape();
   auto [originalStrides, offset] = getStridesAndOffset(memRefType);
   assert(originalStrides.size() == static_cast<unsigned>(rank));
+  assert(permutation.size() == static_cast<unsigned>(rank));
 
   // Compute permuted sizes and strides.
   SmallVector<int64_t> sizes(rank, 0);
   SmallVector<int64_t> strides(rank, 1);
-  for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
-    unsigned position = en.value().cast<AffineDimExpr>().getPosition();
-    sizes[en.index()] = originalSizes[position];
-    strides[en.index()] = originalStrides[position];
+  for (int64_t resultDimPos = 0; resultDimPos < rank; ++resultDimPos) {
+    int64_t originalDimPos = permutation[resultDimPos];
+    sizes[resultDimPos] = originalSizes[originalDimPos];
+    strides[resultDimPos] = originalStrides[originalDimPos];
   }
 
   return MemRefType::Builder(memRefType)
@@ -3200,52 +3201,59 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
 }
 
 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
-                        AffineMapAttr permutation,
+                        DenseI64ArrayAttr permutation,
                         ArrayRef<NamedAttribute> attrs) {
-  auto permutationMap = permutation.getValue();
-  assert(permutationMap);
-
   auto memRefType = llvm::cast<MemRefType>(in.getType());
   // Compute result type.
-  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
+  MemRefType resultType =
+      inferTransposeResultType(memRefType, permutation.asArrayRef());
 
   build(b, result, resultType, in, attrs);
   result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
 }
 
-// transpose $in $permutation attr-dict : type($in) `to` type(results)
-void TransposeOp::print(OpAsmPrinter &p) {
-  p << " " << getIn() << " " << getPermutation();
-  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
-  p << " : " << getIn().getType() << " to " << getType();
+void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
+                        ArrayRef<int64_t> permutation,
+                        ArrayRef<NamedAttribute> attrs) {
+  auto memRefType = llvm::cast<MemRefType>(in.getType());
+  // Compute result type.
+  MemRefType resultType = inferTransposeResultType(memRefType, permutation);
+
+  build(b, result, resultType, in, attrs);
+  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
+                      b.getDenseI64ArrayAttr(permutation));
 }
 
-ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::UnresolvedOperand in;
-  AffineMap permutation;
-  MemRefType srcType, dstType;
-  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(srcType) ||
-      parser.resolveOperand(in, srcType, result.operands) ||
-      parser.parseKeywordType("to", dstType) ||
-      parser.addTypeToList(dstType, result.types))
-    return failure();
+/// Check whether the supplied array is a permutation index array, i.e. it
+/// contains each of the elements 0..size()-1 exactly once.
+static bool isPermutationArray(ArrayRef<int64_t> arr) {
+  for (int64_t i = 0, e = arr.size(); i < e; ++i) {
+    bool found = false;
+    for (int64_t j = 0; j < e; ++j) {
+      if (arr[j] == i) {
+        found = true;
+        break;
+      }
+    }
 
-  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
-                      AffineMapAttr::get(permutation));
-  return success();
+    if (!found)
+      return false;
+  }
+
+  return true;
 }
 
 LogicalResult TransposeOp::verify() {
-  if (!getPermutation().isPermutation())
+  ArrayRef<int64_t> permutation = getPermutation();
+
+  if (!isPermutationArray(permutation))
     return emitOpError("expected a permutation map");
-  if (getPermutation().getNumDims() != getIn().getType().getRank())
+  if (static_cast<int64_t>(permutation.size()) != getIn().getType().getRank())
     return emitOpError("expected a permutation map of same rank as the input");
 
   auto srcType = llvm::cast<MemRefType>(getIn().getType());
   auto dstType = llvm::cast<MemRefType>(getType());
-  auto transposedType = inferTransposeResultType(srcType, getPermutation());
+  auto transposedType = inferTransposeResultType(srcType, permutation);
   if (dstType != transposedType)
     return emitOpError("output type ")
            << dstType << " does not match transposed input type " << srcType
@@ -3259,6 +3267,15 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
   return {};
 }
 
+bool TransposeOp::isIdentity() {
+  ArrayRef<int64_t> permutationArray = getPermutation();
+  for (int64_t i = 0, rank = permutationArray.size(); i < rank; ++i)
+    if (permutationArray[i] != i)
+      return false;
+
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // ViewOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index ae487ef6694745d..0d1b2219d6ac094 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -243,7 +243,7 @@ func.func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
 //       CHECK:   llvm.extractvalue {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 //       CHECK:    llvm.insertvalue {{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 func.func @transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
-  %0 = memref.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
+  %0 = memref.transpose %arg0 [2, 0, 1] : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
   return
 }
 
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 6203cf1c76d144c..b909b46095f053e 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -68,11 +68,11 @@ func.func @fill_view(%arg0: memref<?xf32, strided<[1], offset: ?>>, %arg1: f32)
 // -----
 
 func.func @memref_transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
-  %0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
+  %0 = memref.transpose %arg0 [2, 1, 0] : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
   return
 }
 // CHECK-LABEL: func @memref_transpose
-//       CHECK:   memref.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
+//       CHECK:   memref.transpose %{{.*}} [2, 1, 0] :
 //  CHECK-SAME:      memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
 
 // -----
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index cb5977e302a993f..e4037b4bf0cdc2a 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -129,21 +129,28 @@ func.func @dma_wait_wrong_index_type(%tag : memref<2x2xi32>, %idx: index, %flt:
 
 func.func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
   // expected-error @+1 {{expected a permutation map}}
-  memref.transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+  memref.transpose %v [1, 1] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
 
 // -----
 
 func.func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
   // expected-error @+1 {{expected a permutation map of same rank as the input}}
-  memref.transpose %v (i) -> (i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+  memref.transpose %v [0] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
 
 // -----
 
 func.func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
   // expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
-  memref.transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+  memref.transpose %v [1, 0] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
+
+// -----
+
+func.func @transpose_out_of_range(%v : memref<?x?xf32>) {
+  // expected-error @+1 {{expected a permutation map}}
+  memref.transpose %v [0, 2] : memref<?x?xf32> to memref<?x?xf32>
+}
 
 // -----



More information about the Mlir-commits mailing list