[Mlir-commits] [mlir] [mlir][memref] Use array notation instead of permutation map for memref.transpose (PR #67880)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Sep 30 04:37:58 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref
<details>
<summary>Changes</summary>
Until now, the dimensional permutation for memref.transpose was given in the form of an affine map. However, just from looking at such a representation, e.g. `(i, j) -> (j, i)`, it's not obvious whether it represents a mapping from the result dimensions to the source dimensions or the other way around. This has led to a bug (#65145).
This patch switches `memref.transpose` to the integer-array notation already used by ops such as `linalg.transpose` and `memref.collapse_shape`; that notation is harder to misinterpret and easier to work with.
---
Full diff: https://github.com/llvm/llvm-project/pull/67880.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+15-4)
- (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+14-17)
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+50-33)
- (modified) mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir (+1-1)
- (modified) mlir/test/Dialect/Linalg/roundtrip.mlir (+2-2)
- (modified) mlir/test/Dialect/MemRef/invalid.mlir (+3-3)
``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index ea6e363a6c3257f..30eb3feb097bf81 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -2119,7 +2119,7 @@ def TensorStoreOp : MemRef_Op<"tensor_store",
def MemRef_TransposeOp : MemRef_Op<"transpose", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Pure]>,
- Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
+ Arguments<(ins AnyStridedMemRef:$in, DenseI64ArrayAttr:$permutation)>,
Results<(outs AnyStridedMemRef)> {
let summary = "`transpose` produces a new strided memref (metadata-only)";
let description = [{
@@ -2127,22 +2127,33 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
are a permutation of the original `in` memref. This is purely a metadata
transformation.
+ The permutation is given in the form of an array of indices following the rule:
+ `dim(result, i) = dim(input, permutation[i])`
+
Example:
```mlir
- %1 = memref.transpose %0 (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+ %1 = memref.transpose %0 [1, 0] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
```
}];
let builders = [
- OpBuilder<(ins "Value":$in, "AffineMapAttr":$permutation,
+ OpBuilder<(ins "Value":$in, "DenseI64ArrayAttr":$permutation,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+ OpBuilder<(ins "Value":$in, "ArrayRef<int64_t>":$permutation,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
let extraClassDeclaration = [{
static StringRef getPermutationAttrStrName() { return "permutation"; }
+
+ /// Returns true if the permutation is the identity, i.e. permutation[i] == i for all i.
+ bool isIdentity();
+ }];
+
+ let assemblyFormat = [{
+ $in $permutation attr-dict `:` type($in) `to` type(results)
}];
- let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 61bd23f12601c79..2e34b690ae7a55f 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1410,38 +1410,35 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
MemRefDescriptor viewMemRef(adaptor.getIn());
// No permutation, early exit.
- if (transposeOp.getPermutation().isIdentity())
+ if (transposeOp.isIdentity())
return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
- auto targetMemRef = MemRefDescriptor::undef(
+ auto resultMemRef = MemRefDescriptor::undef(
rewriter, loc,
typeConverter->convertType(transposeOp.getIn().getType()));
// Copy the base and aligned pointers from the old descriptor to the new
// one.
- targetMemRef.setAllocatedPtr(rewriter, loc,
+ resultMemRef.setAllocatedPtr(rewriter, loc,
viewMemRef.allocatedPtr(rewriter, loc));
- targetMemRef.setAlignedPtr(rewriter, loc,
+ resultMemRef.setAlignedPtr(rewriter, loc,
viewMemRef.alignedPtr(rewriter, loc));
// Copy the offset pointer from the old descriptor to the new one.
- targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
+ resultMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
// Iterate over the dimensions and apply size/stride permutation:
- // When enumerating the results of the permutation map, the enumeration index
- // is the index into the target dimensions and the DimExpr points to the
- // dimension of the source memref.
- for (const auto &en :
- llvm::enumerate(transposeOp.getPermutation().getResults())) {
- int targetPos = en.index();
- int sourcePos = en.value().cast<AffineDimExpr>().getPosition();
- targetMemRef.setSize(rewriter, loc, targetPos,
- viewMemRef.size(rewriter, loc, sourcePos));
- targetMemRef.setStride(rewriter, loc, targetPos,
- viewMemRef.stride(rewriter, loc, sourcePos));
+ ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+ for (int64_t resultDimPos = 0, rank = permutation.size();
+ resultDimPos < rank; ++resultDimPos) {
+ int originalDimPos = permutation[resultDimPos];
+ resultMemRef.setSize(rewriter, loc, resultDimPos,
+ viewMemRef.size(rewriter, loc, originalDimPos));
+ resultMemRef.setStride(rewriter, loc, resultDimPos,
+ viewMemRef.stride(rewriter, loc, originalDimPos));
}
- rewriter.replaceOp(transposeOp, {targetMemRef});
+ rewriter.replaceOp(transposeOp, {resultMemRef});
return success();
}
};
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 215a8f5e7d18be0..fa28c850aea960a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3176,21 +3176,22 @@ void TransposeOp::getAsmResultNames(
setNameFn(getResult(), "transpose");
}
-/// Build a strided memref type by applying `permutationMap` tp `memRefType`.
+/// Build a strided memref type by applying `permutation` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
- AffineMap permutationMap) {
+ ArrayRef<int64_t> permutation) {
auto rank = memRefType.getRank();
auto originalSizes = memRefType.getShape();
auto [originalStrides, offset] = getStridesAndOffset(memRefType);
assert(originalStrides.size() == static_cast<unsigned>(rank));
+ assert(permutation.size() == rank);
// Compute permuted sizes and strides.
SmallVector<int64_t> sizes(rank, 0);
SmallVector<int64_t> strides(rank, 1);
- for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
- unsigned position = en.value().cast<AffineDimExpr>().getPosition();
- sizes[en.index()] = originalSizes[position];
- strides[en.index()] = originalStrides[position];
+ for (int64_t resultDimPos = 0; resultDimPos < rank; ++resultDimPos) {
+ int64_t originalDimPos = permutation[resultDimPos];
+ sizes[resultDimPos] = originalSizes[originalDimPos];
+ strides[resultDimPos] = originalStrides[originalDimPos];
}
return MemRefType::Builder(memRefType)
@@ -3200,52 +3201,59 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
}
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
- AffineMapAttr permutation,
+ DenseI64ArrayAttr permutation,
ArrayRef<NamedAttribute> attrs) {
- auto permutationMap = permutation.getValue();
- assert(permutationMap);
-
+ assert(permutation && "expected a non-null permutation attribute");
auto memRefType = llvm::cast<MemRefType>(in.getType());
// Compute result type.
- MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
+ MemRefType resultType = inferTransposeResultType(memRefType, permutation);
build(b, result, resultType, in, attrs);
result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
}
-// transpose $in $permutation attr-dict : type($in) `to` type(results)
-void TransposeOp::print(OpAsmPrinter &p) {
- p << " " << getIn() << " " << getPermutation();
- p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
- p << " : " << getIn().getType() << " to " << getType();
+/// Build a TransposeOp from a plain index array, following the rule
+/// `dim(result, i) = dim(in, permutation[i])`. The result type is inferred
+/// from the type of `in` and `permutation`.
+void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
+ ArrayRef<int64_t> permutation,
+ ArrayRef<NamedAttribute> attrs) {
+ // Wrap the indices in an attribute and delegate to the attribute-based
+ // builder so that result-type inference and attribute registration are
+ // implemented in exactly one place.
+ build(b, result, in, b.getDenseI64ArrayAttr(permutation), attrs);
}
-ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
- OpAsmParser::UnresolvedOperand in;
- AffineMap permutation;
- MemRefType srcType, dstType;
- if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(srcType) ||
- parser.resolveOperand(in, srcType, result.operands) ||
- parser.parseKeywordType("to", dstType) ||
- parser.addTypeToList(dstType, result.types))
- return failure();
+/// Check whether the supplied array is a permutation index array, i.e. it
+/// contains each of the elements 0..size()-1 exactly once. An empty array is
+/// the (trivial) permutation of rank 0.
+static bool isPermutationArray(ArrayRef<int64_t> arr) {
+ int64_t size = arr.size();
+ // `arr` has exactly `size` entries, so it is a permutation iff every entry
+ // is in range and no entry repeats. Both conditions are checked in a single
+ // linear pass instead of searching `arr` once per expected index.
+ SmallVector<bool> seen(size, false);
+ for (int64_t index : arr) {
+ if (index < 0 || index >= size || seen[index])
+ return false;
+ seen[index] = true;
+ }
- result.addAttribute(TransposeOp::getPermutationAttrStrName(),
- AffineMapAttr::get(permutation));
- return success();
+
+ return true;
}
LogicalResult TransposeOp::verify() {
- if (!getPermutation().isPermutation())
+ ArrayRef<int64_t> permutation = getPermutation();
+
+ if (!isPermutationArray(permutation))
return emitOpError("expected a permutation map");
- if (getPermutation().getNumDims() != getIn().getType().getRank())
+ if (permutation.size() != getIn().getType().getRank())
return emitOpError("expected a permutation map of same rank as the input");
auto srcType = llvm::cast<MemRefType>(getIn().getType());
auto dstType = llvm::cast<MemRefType>(getType());
- auto transposedType = inferTransposeResultType(srcType, getPermutation());
+ auto transposedType = inferTransposeResultType(srcType, permutation);
if (dstType != transposedType)
return emitOpError("output type ")
<< dstType << " does not match transposed input type " << srcType
@@ -3259,6 +3267,15 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
return {};
}
+bool TransposeOp::isIdentity() {
+ // Identity iff every result dimension `i` is taken from input dimension `i`.
+ ArrayRef<int64_t> perm = getPermutation();
+ for (const auto &en : llvm::enumerate(perm))
+ if (en.value() != static_cast<int64_t>(en.index()))
+ return false;
+ return true;
+}
+
//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 9e44029ad93bd9c..355c9d494208212 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -243,7 +243,7 @@ func.func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
// CHECK: llvm.extractvalue {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
func.func @transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
- %0 = memref.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
+ %0 = memref.transpose %arg0 [2, 0, 1] : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
return
}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 6203cf1c76d144c..b909b46095f053e 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -68,11 +68,11 @@ func.func @fill_view(%arg0: memref<?xf32, strided<[1], offset: ?>>, %arg1: f32)
// -----
func.func @memref_transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
- %0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
+ %0 = memref.transpose %arg0 [2, 1, 0] : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
return
}
// CHECK-LABEL: func @memref_transpose
-// CHECK: memref.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
+// CHECK: memref.transpose %{{.*}} [2, 1, 0] :
// CHECK-SAME: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
// -----
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index cb5977e302a993f..e4037b4bf0cdc2a 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -129,21 +129,21 @@ func.func @dma_wait_wrong_index_type(%tag : memref<2x2xi32>, %idx: index, %flt:
func.func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
// expected-error @+1 {{expected a permutation map}}
- memref.transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+ memref.transpose %v [1, 1] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}
// -----
func.func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
// expected-error @+1 {{expected a permutation map of same rank as the input}}
- memref.transpose %v (i) -> (i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+ memref.transpose %v [0] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}
// -----
func.func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
// expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
- memref.transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+ memref.transpose %v [1, 0] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/67880
More information about the Mlir-commits
mailing list