[Mlir-commits] [mlir] [mlir][memref] Use array notation instead of permutation map for memref.transpose (PR #67880)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Sep 30 04:37:58 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

<details>
<summary>Changes</summary>

Until now, the dimensional permutation for memref.transpose was given in the form of an affine map. However, just from looking at such a representation, e.g. `(i, j) -> (j, i)`, it's not obvious whether it represents a mapping from the result dimensions to the source dimensions or the other way around. This has led to a bug (#<!-- -->65145).

This patch switches `memref.transpose` to the integer-array-based notation that is already used by ops such as `linalg.transpose` and `memref.collapse_shape`. This notation is harder to misinterpret and easier to work with.

---
Full diff: https://github.com/llvm/llvm-project/pull/67880.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+15-4) 
- (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+14-17) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+50-33) 
- (modified) mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir (+1-1) 
- (modified) mlir/test/Dialect/Linalg/roundtrip.mlir (+2-2) 
- (modified) mlir/test/Dialect/MemRef/invalid.mlir (+3-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index ea6e363a6c3257f..30eb3feb097bf81 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -2119,7 +2119,7 @@ def TensorStoreOp : MemRef_Op<"tensor_store",
 def MemRef_TransposeOp : MemRef_Op<"transpose", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     Pure]>,
-    Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
+    Arguments<(ins AnyStridedMemRef:$in, DenseI64ArrayAttr:$permutation)>,
     Results<(outs AnyStridedMemRef)> {
   let summary = "`transpose` produces a new strided memref (metadata-only)";
   let description = [{
@@ -2127,22 +2127,33 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
     are a permutation of the original `in` memref. This is purely a metadata
     transformation.
 
+    The permutation is given in the form of an array of indices following the rule:
+    `dim(result, i) = dim(input, permutation[i])`
+
     Example:
 
     ```mlir
-    %1 = memref.transpose %0 (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+    %1 = memref.transpose %0 [1, 0] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
     ```
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$in, "AffineMapAttr":$permutation,
+    OpBuilder<(ins "Value":$in, "DenseI64ArrayAttr":$permutation,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    OpBuilder<(ins "Value":$in, "ArrayRef<int64_t>":$permutation,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
 
   let extraClassDeclaration = [{
     static StringRef getPermutationAttrStrName() { return "permutation"; }
+
+    /// Returns true if the permutation represents an identity permutation
+    bool isIdentity();
+  }];
+
+  let assemblyFormat = [{
+    $in $permutation attr-dict `:` type($in) `to` type(results)
   }];
 
-  let hasCustomAssemblyFormat = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 61bd23f12601c79..2e34b690ae7a55f 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1410,38 +1410,35 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
     MemRefDescriptor viewMemRef(adaptor.getIn());
 
     // No permutation, early exit.
-    if (transposeOp.getPermutation().isIdentity())
+    if (transposeOp.isIdentity())
       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
 
-    auto targetMemRef = MemRefDescriptor::undef(
+    auto resultMemRef = MemRefDescriptor::undef(
         rewriter, loc,
         typeConverter->convertType(transposeOp.getIn().getType()));
 
     // Copy the base and aligned pointers from the old descriptor to the new
     // one.
-    targetMemRef.setAllocatedPtr(rewriter, loc,
+    resultMemRef.setAllocatedPtr(rewriter, loc,
                                  viewMemRef.allocatedPtr(rewriter, loc));
-    targetMemRef.setAlignedPtr(rewriter, loc,
+    resultMemRef.setAlignedPtr(rewriter, loc,
                                viewMemRef.alignedPtr(rewriter, loc));
 
     // Copy the offset pointer from the old descriptor to the new one.
-    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
+    resultMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
 
     // Iterate over the dimensions and apply size/stride permutation:
-    // When enumerating the results of the permutation map, the enumeration index
-    // is the index into the target dimensions and the DimExpr points to the
-    // dimension of the source memref.
-    for (const auto &en :
-         llvm::enumerate(transposeOp.getPermutation().getResults())) {
-      int targetPos = en.index();
-      int sourcePos = en.value().cast<AffineDimExpr>().getPosition();
-      targetMemRef.setSize(rewriter, loc, targetPos,
-                           viewMemRef.size(rewriter, loc, sourcePos));
-      targetMemRef.setStride(rewriter, loc, targetPos,
-                             viewMemRef.stride(rewriter, loc, sourcePos));
+    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+    for (int64_t resultDimPos = 0, rank = permutation.size();
+         resultDimPos < rank; ++resultDimPos) {
+      int originalDimPos = permutation[resultDimPos];
+      resultMemRef.setSize(rewriter, loc, resultDimPos,
+                           viewMemRef.size(rewriter, loc, originalDimPos));
+      resultMemRef.setStride(rewriter, loc, resultDimPos,
+                             viewMemRef.stride(rewriter, loc, originalDimPos));
     }
 
-    rewriter.replaceOp(transposeOp, {targetMemRef});
+    rewriter.replaceOp(transposeOp, {resultMemRef});
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 215a8f5e7d18be0..fa28c850aea960a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3176,21 +3176,22 @@ void TransposeOp::getAsmResultNames(
   setNameFn(getResult(), "transpose");
 }
 
-/// Build a strided memref type by applying `permutationMap` tp `memRefType`.
+/// Build a strided memref type by applying `permutation` to `memRefType`.
 static MemRefType inferTransposeResultType(MemRefType memRefType,
-                                           AffineMap permutationMap) {
+                                           ArrayRef<int64_t> permutation) {
   auto rank = memRefType.getRank();
   auto originalSizes = memRefType.getShape();
   auto [originalStrides, offset] = getStridesAndOffset(memRefType);
   assert(originalStrides.size() == static_cast<unsigned>(rank));
+  assert(permutation.size() == rank);
 
   // Compute permuted sizes and strides.
   SmallVector<int64_t> sizes(rank, 0);
   SmallVector<int64_t> strides(rank, 1);
-  for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
-    unsigned position = en.value().cast<AffineDimExpr>().getPosition();
-    sizes[en.index()] = originalSizes[position];
-    strides[en.index()] = originalStrides[position];
+  for (int64_t resultDimPos = 0; resultDimPos < rank; ++resultDimPos) {
+    int64_t originalDimPos = permutation[resultDimPos];
+    sizes[resultDimPos] = originalSizes[originalDimPos];
+    strides[resultDimPos] = originalStrides[originalDimPos];
   }
 
   return MemRefType::Builder(memRefType)
@@ -3200,52 +3201,59 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
 }
 
 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
-                        AffineMapAttr permutation,
+                        DenseI64ArrayAttr permutation,
                         ArrayRef<NamedAttribute> attrs) {
-  auto permutationMap = permutation.getValue();
-  assert(permutationMap);
-
   auto memRefType = llvm::cast<MemRefType>(in.getType());
   // Compute result type.
-  MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
+  MemRefType resultType =
+      inferTransposeResultType(memRefType, permutation.asArrayRef());
 
   build(b, result, resultType, in, attrs);
   result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
 }
 
-// transpose $in $permutation attr-dict : type($in) `to` type(results)
-void TransposeOp::print(OpAsmPrinter &p) {
-  p << " " << getIn() << " " << getPermutation();
-  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
-  p << " : " << getIn().getType() << " to " << getType();
+void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
+                        ArrayRef<int64_t> permutation,
+                        ArrayRef<NamedAttribute> attrs) {
+  auto memRefType = llvm::cast<MemRefType>(in.getType());
+  // Compute result type.
+  MemRefType resultType = inferTransposeResultType(memRefType, permutation);
+
+  build(b, result, resultType, in, attrs);
+  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
+                      b.getDenseI64ArrayAttr(permutation));
 }
 
-ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::UnresolvedOperand in;
-  AffineMap permutation;
-  MemRefType srcType, dstType;
-  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(srcType) ||
-      parser.resolveOperand(in, srcType, result.operands) ||
-      parser.parseKeywordType("to", dstType) ||
-      parser.addTypeToList(dstType, result.types))
-    return failure();
+/// Check whether the supplied array is a permutation index array, i.e. it
+/// contains each of the elements 0..size()-1 exactly once.
+static bool isPermutationArray(ArrayRef<int64_t> arr) {
+  for (int64_t i = 0, e = arr.size(); i < e; ++i) {
+    bool found = false;
+    for (int64_t j = 0; j < e; ++j) {
+      if (arr[j] == i) {
+        found = true;
+        break;
+      }
+    }
 
-  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
-                      AffineMapAttr::get(permutation));
-  return success();
+    if (!found)
+      return false;
+  }
+
+  return true;
 }
 
 LogicalResult TransposeOp::verify() {
-  if (!getPermutation().isPermutation())
+  ArrayRef<int64_t> permutation = getPermutation();
+
+  if (!isPermutationArray(permutation))
     return emitOpError("expected a permutation map");
-  if (getPermutation().getNumDims() != getIn().getType().getRank())
+  if (permutation.size() != getIn().getType().getRank())
     return emitOpError("expected a permutation map of same rank as the input");
 
   auto srcType = llvm::cast<MemRefType>(getIn().getType());
   auto dstType = llvm::cast<MemRefType>(getType());
-  auto transposedType = inferTransposeResultType(srcType, getPermutation());
+  auto transposedType = inferTransposeResultType(srcType, permutation);
   if (dstType != transposedType)
     return emitOpError("output type ")
            << dstType << " does not match transposed input type " << srcType
@@ -3259,6 +3267,15 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
   return {};
 }
 
+bool TransposeOp::isIdentity() {
+  ArrayRef<int64_t> permutationArray = getPermutation();
+  for (int64_t i = 0, rank = permutationArray.size(); i < rank; ++i)
+    if (permutationArray[i] != i)
+      return false;
+
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // ViewOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 9e44029ad93bd9c..355c9d494208212 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -243,7 +243,7 @@ func.func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
 //       CHECK:   llvm.extractvalue {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 //       CHECK:    llvm.insertvalue {{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 func.func @transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
-  %0 = memref.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
+  %0 = memref.transpose %arg0 [2, 0, 1] : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
   return
 }
 
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 6203cf1c76d144c..b909b46095f053e 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -68,11 +68,11 @@ func.func @fill_view(%arg0: memref<?xf32, strided<[1], offset: ?>>, %arg1: f32)
 // -----
 
 func.func @memref_transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
-  %0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
+  %0 = memref.transpose %arg0 [2, 1, 0] : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
   return
 }
 // CHECK-LABEL: func @memref_transpose
-//       CHECK:   memref.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
+//       CHECK:   memref.transpose %{{.*}} [2, 1, 0] :
 //  CHECK-SAME:      memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
 
 // -----
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index cb5977e302a993f..e4037b4bf0cdc2a 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -129,21 +129,21 @@ func.func @dma_wait_wrong_index_type(%tag : memref<2x2xi32>, %idx: index, %flt:
 
 func.func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
   // expected-error @+1 {{expected a permutation map}}
-  memref.transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+  memref.transpose %v [1, 1] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
 
 // -----
 
 func.func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
   // expected-error @+1 {{expected a permutation map of same rank as the input}}
-  memref.transpose %v (i) -> (i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+  memref.transpose %v [0] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
 
 // -----
 
 func.func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
   // expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
-  memref.transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+  memref.transpose %v [1, 0] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
 
 // -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/67880


More information about the Mlir-commits mailing list