[Mlir-commits] [mlir] [mlir][memref] Use array notation instead of permutation map for memref.transpose (PR #67880)
Felix Schneider
llvmlistbot at llvm.org
Sat Sep 30 23:54:05 PDT 2023
https://github.com/ubfx updated https://github.com/llvm/llvm-project/pull/67880
>From 64ca3c026abd19a78263dd9a70acccbdbaa5e397 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 30 Sep 2023 11:33:22 +0000
Subject: [PATCH 1/3] [mlir][memref] Use array notation instead of permutation
map for memref.transpose
Until now, the dimensional permutation for memref.transpose was given in the
form of an affine map. However, just from looking at such a representation,
e.g. `(i, j) -> (j, i)`, it's not obvious whether it represents a mapping from
the result dimensions to the source dimensions or the other way around. This has
led to a bug (#65145).
This patch switches `memref.transpose` to the integer-array notation already
used by ops such as `linalg.transpose` and `memref.collapse_shape`. That
notation is harder to misinterpret and easier to work with.
---
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 19 ++++-
.../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 31 ++++---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 83 +++++++++++--------
.../MemRefToLLVM/memref-to-llvm.mlir | 2 +-
mlir/test/Dialect/Linalg/roundtrip.mlir | 4 +-
mlir/test/Dialect/MemRef/invalid.mlir | 6 +-
6 files changed, 85 insertions(+), 60 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 6b0ccbe37e89e9c..067737d8a18fda9 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -2125,7 +2125,7 @@ def TensorStoreOp : MemRef_Op<"tensor_store",
def MemRef_TransposeOp : MemRef_Op<"transpose", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Pure]>,
- Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
+ Arguments<(ins AnyStridedMemRef:$in, DenseI64ArrayAttr:$permutation)>,
Results<(outs AnyStridedMemRef)> {
let summary = "`transpose` produces a new strided memref (metadata-only)";
let description = [{
@@ -2133,22 +2133,33 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
are a permutation of the original `in` memref. This is purely a metadata
transformation.
+ The permutation is given in the form of an array of indices following the rule:
+ `dim(result, i) = dim(input, permutation[i])`
+
Example:
```mlir
- %1 = memref.transpose %0 (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+ %1 = memref.transpose %0 [1, 0] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
```
}];
let builders = [
- OpBuilder<(ins "Value":$in, "AffineMapAttr":$permutation,
+ OpBuilder<(ins "Value":$in, "DenseI64ArrayAttr":$permutation,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+ OpBuilder<(ins "Value":$in, "ArrayRef<int64_t>":$permutation,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
let extraClassDeclaration = [{
static StringRef getPermutationAttrStrName() { return "permutation"; }
+
+ /// Returns true if the permutation represents an identity permutation
+ bool isIdentity();
+ }];
+
+ let assemblyFormat = [{
+ $in $permutation attr-dict `:` type($in) `to` type(results)
}];
- let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 61bd23f12601c79..2e34b690ae7a55f 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1410,38 +1410,35 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
MemRefDescriptor viewMemRef(adaptor.getIn());
// No permutation, early exit.
- if (transposeOp.getPermutation().isIdentity())
+ if (transposeOp.isIdentity())
return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
- auto targetMemRef = MemRefDescriptor::undef(
+ auto resultMemRef = MemRefDescriptor::undef(
rewriter, loc,
typeConverter->convertType(transposeOp.getIn().getType()));
// Copy the base and aligned pointers from the old descriptor to the new
// one.
- targetMemRef.setAllocatedPtr(rewriter, loc,
+ resultMemRef.setAllocatedPtr(rewriter, loc,
viewMemRef.allocatedPtr(rewriter, loc));
- targetMemRef.setAlignedPtr(rewriter, loc,
+ resultMemRef.setAlignedPtr(rewriter, loc,
viewMemRef.alignedPtr(rewriter, loc));
// Copy the offset pointer from the old descriptor to the new one.
- targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
+ resultMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
// Iterate over the dimensions and apply size/stride permutation:
- // When enumerating the results of the permutation map, the enumeration index
- // is the index into the target dimensions and the DimExpr points to the
- // dimension of the source memref.
- for (const auto &en :
- llvm::enumerate(transposeOp.getPermutation().getResults())) {
- int targetPos = en.index();
- int sourcePos = en.value().cast<AffineDimExpr>().getPosition();
- targetMemRef.setSize(rewriter, loc, targetPos,
- viewMemRef.size(rewriter, loc, sourcePos));
- targetMemRef.setStride(rewriter, loc, targetPos,
- viewMemRef.stride(rewriter, loc, sourcePos));
+ ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+ for (int64_t resultDimPos = 0, rank = permutation.size();
+ resultDimPos < rank; ++resultDimPos) {
+ int originalDimPos = permutation[resultDimPos];
+ resultMemRef.setSize(rewriter, loc, resultDimPos,
+ viewMemRef.size(rewriter, loc, originalDimPos));
+ resultMemRef.setStride(rewriter, loc, resultDimPos,
+ viewMemRef.stride(rewriter, loc, originalDimPos));
}
- rewriter.replaceOp(transposeOp, {targetMemRef});
+ rewriter.replaceOp(transposeOp, {resultMemRef});
return success();
}
};
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 215a8f5e7d18be0..fa28c850aea960a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3176,21 +3176,22 @@ void TransposeOp::getAsmResultNames(
setNameFn(getResult(), "transpose");
}
-/// Build a strided memref type by applying `permutationMap` tp `memRefType`.
+/// Build a strided memref type by applying `permutation` tp `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
- AffineMap permutationMap) {
+ ArrayRef<int64_t> permutation) {
auto rank = memRefType.getRank();
auto originalSizes = memRefType.getShape();
auto [originalStrides, offset] = getStridesAndOffset(memRefType);
assert(originalStrides.size() == static_cast<unsigned>(rank));
+ assert(permutation.size() == static_cast<size_t>(rank) && "permutation size must equal memref rank");
// Compute permuted sizes and strides.
SmallVector<int64_t> sizes(rank, 0);
SmallVector<int64_t> strides(rank, 1);
- for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
- unsigned position = en.value().cast<AffineDimExpr>().getPosition();
- sizes[en.index()] = originalSizes[position];
- strides[en.index()] = originalStrides[position];
+ for (int64_t resultDimPos = 0; resultDimPos < rank; ++resultDimPos) {
+ int64_t originalDimPos = permutation[resultDimPos];
+ sizes[resultDimPos] = originalSizes[originalDimPos];
+ strides[resultDimPos] = originalStrides[originalDimPos];
}
return MemRefType::Builder(memRefType)
@@ -3200,52 +3201,59 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
}
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
- AffineMapAttr permutation,
+ DenseI64ArrayAttr permutation,
ArrayRef<NamedAttribute> attrs) {
- auto permutationMap = permutation.getValue();
- assert(permutationMap);
-
auto memRefType = llvm::cast<MemRefType>(in.getType());
// Compute result type.
- MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
+ MemRefType resultType =
+ inferTransposeResultType(memRefType, permutation.asArrayRef());
build(b, result, resultType, in, attrs);
result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
}
-// transpose $in $permutation attr-dict : type($in) `to` type(results)
-void TransposeOp::print(OpAsmPrinter &p) {
- p << " " << getIn() << " " << getPermutation();
- p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
- p << " : " << getIn().getType() << " to " << getType();
+void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
+ ArrayRef<int64_t> permutation,
+ ArrayRef<NamedAttribute> attrs) {
+ auto memRefType = llvm::cast<MemRefType>(in.getType());
+ // Compute result type.
+ MemRefType resultType = inferTransposeResultType(memRefType, permutation);
+
+ build(b, result, resultType, in, attrs);
+ result.addAttribute(TransposeOp::getPermutationAttrStrName(),
+ b.getDenseI64ArrayAttr(permutation));
}
-ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
- OpAsmParser::UnresolvedOperand in;
- AffineMap permutation;
- MemRefType srcType, dstType;
- if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(srcType) ||
- parser.resolveOperand(in, srcType, result.operands) ||
- parser.parseKeywordType("to", dstType) ||
- parser.addTypeToList(dstType, result.types))
- return failure();
+/// Check whether the supplied array is a permutation index array, i.e. it
+/// contains the elements 0..size()-1.
+static bool isPermutationArray(ArrayRef<int64_t> arr) {
+ for (int64_t i = 0, e = arr.size(); i < e; ++i) {
+ bool found = false;
+ for (int64_t j = 0; j < e; ++j) {
+ if (arr[j] == i) {
+ found = true;
+ break;
+ }
+ }
- result.addAttribute(TransposeOp::getPermutationAttrStrName(),
- AffineMapAttr::get(permutation));
- return success();
+ if (!found)
+ return false;
+ }
+
+ return true;
}
LogicalResult TransposeOp::verify() {
- if (!getPermutation().isPermutation())
+ ArrayRef<int64_t> permutation = getPermutation();
+
+ if (!isPermutationArray(permutation))
return emitOpError("expected a permutation map");
- if (getPermutation().getNumDims() != getIn().getType().getRank())
+ if (permutation.size() != getIn().getType().getRank())
return emitOpError("expected a permutation map of same rank as the input");
auto srcType = llvm::cast<MemRefType>(getIn().getType());
auto dstType = llvm::cast<MemRefType>(getType());
- auto transposedType = inferTransposeResultType(srcType, getPermutation());
+ auto transposedType = inferTransposeResultType(srcType, permutation);
if (dstType != transposedType)
return emitOpError("output type ")
<< dstType << " does not match transposed input type " << srcType
@@ -3259,6 +3267,15 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
return {};
}
+bool TransposeOp::isIdentity() {
+ ArrayRef<int64_t> permutationArray = getPermutation();
+ for (int64_t i = 0, rank = permutationArray.size(); i < rank; ++i)
+ if (permutationArray[i] != i)
+ return false;
+
+ return true;
+}
+
//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index ae487ef6694745d..0d1b2219d6ac094 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -243,7 +243,7 @@ func.func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
// CHECK: llvm.extractvalue {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
func.func @transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
- %0 = memref.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
+ %0 = memref.transpose %arg0 [2, 0, 1] : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
return
}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 6203cf1c76d144c..b909b46095f053e 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -68,11 +68,11 @@ func.func @fill_view(%arg0: memref<?xf32, strided<[1], offset: ?>>, %arg1: f32)
// -----
func.func @memref_transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
- %0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
+ %0 = memref.transpose %arg0 [2, 1, 0] : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
return
}
// CHECK-LABEL: func @memref_transpose
-// CHECK: memref.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
+// CHECK: memref.transpose %{{.*}} [2, 1, 0] :
// CHECK-SAME: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
// -----
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index cb5977e302a993f..e4037b4bf0cdc2a 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -129,21 +129,21 @@ func.func @dma_wait_wrong_index_type(%tag : memref<2x2xi32>, %idx: index, %flt:
func.func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
// expected-error @+1 {{expected a permutation map}}
- memref.transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+ memref.transpose %v [1, 1] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}
// -----
func.func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
// expected-error @+1 {{expected a permutation map of same rank as the input}}
- memref.transpose %v (i) -> (i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+ memref.transpose %v [0] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}
// -----
func.func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
// expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
- memref.transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
+ memref.transpose %v [1, 0] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}
// -----
>From b8a05d5c4c1a8b2a21cd007cabe8cafbbcdc884c Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 30 Sep 2023 13:47:36 +0000
Subject: [PATCH 2/3] typo
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index fa28c850aea960a..e2a1861046da870 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3176,7 +3176,7 @@ void TransposeOp::getAsmResultNames(
setNameFn(getResult(), "transpose");
}
-/// Build a strided memref type by applying `permutation` tp `memRefType`.
+/// Build a strided memref type by applying `permutation` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
ArrayRef<int64_t> permutation) {
auto rank = memRefType.getRank();
>From c33e4857e8572d7168457cfe78086f4893e77980 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sun, 1 Oct 2023 08:53:35 +0200
Subject: [PATCH 3/3] address review comments
---
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 2 +-
.../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 6 ++----
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 21 +++++++------------
mlir/test/Dialect/MemRef/invalid.mlir | 4 ++--
4 files changed, 13 insertions(+), 20 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 067737d8a18fda9..4ff3bcf27f628b8 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -2139,7 +2139,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
Example:
```mlir
- %1 = memref.transpose %0 [1, 0] : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+ %1 = memref.transpose %0 [1, 2, 0] : memref<7x8x9xf32> to memref<8x9x7xf32>
```
}];
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 2e34b690ae7a55f..c290ac5e419f61e 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1428,10 +1428,8 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
resultMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
// Iterate over the dimensions and apply size/stride permutation:
- ArrayRef<int64_t> permutation = transposeOp.getPermutation();
- for (int64_t resultDimPos = 0, rank = permutation.size();
- resultDimPos < rank; ++resultDimPos) {
- int originalDimPos = permutation[resultDimPos];
+ for (auto [resultDimPos, originalDimPos] :
+ llvm::enumerate(transposeOp.getPermutation())) {
resultMemRef.setSize(rewriter, loc, resultDimPos,
viewMemRef.size(rewriter, loc, originalDimPos));
resultMemRef.setStride(rewriter, loc, resultDimPos,
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e2a1861046da870..d8e9cf864a35cf4 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3188,8 +3188,7 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
// Compute permuted sizes and strides.
SmallVector<int64_t> sizes(rank, 0);
SmallVector<int64_t> strides(rank, 1);
- for (int64_t resultDimPos = 0; resultDimPos < rank; ++resultDimPos) {
- int64_t originalDimPos = permutation[resultDimPos];
+ for (auto [resultDimPos, originalDimPos] : llvm::enumerate(permutation)) {
sizes[resultDimPos] = originalSizes[originalDimPos];
strides[resultDimPos] = originalStrides[originalDimPos];
}
@@ -3229,27 +3228,25 @@ void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
static bool isPermutationArray(ArrayRef<int64_t> arr) {
for (int64_t i = 0, e = arr.size(); i < e; ++i) {
bool found = false;
- for (int64_t j = 0; j < e; ++j) {
- if (arr[j] == i) {
+ for (auto dim : arr) {
+ if (dim == i) {
found = true;
break;
}
}
-
if (!found)
return false;
}
-
return true;
}
LogicalResult TransposeOp::verify() {
ArrayRef<int64_t> permutation = getPermutation();
-
if (!isPermutationArray(permutation))
- return emitOpError("expected a permutation map");
+ return emitOpError("expected a permutation array");
if (permutation.size() != getIn().getType().getRank())
- return emitOpError("expected a permutation map of same rank as the input");
+ return emitOpError(
+ "expected a permutation array of same size as the input rank");
auto srcType = llvm::cast<MemRefType>(getIn().getType());
auto dstType = llvm::cast<MemRefType>(getType());
@@ -3268,11 +3265,9 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
}
bool TransposeOp::isIdentity() {
- ArrayRef<int64_t> permutationArray = getPermutation();
- for (int64_t i = 0, rank = permutationArray.size(); i < rank; ++i)
- if (permutationArray[i] != i)
+ for (auto [index, dim] : llvm::enumerate(getPermutation()))
+ if (static_cast<int64_t>(index) != dim)
return false;
-
return true;
}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index e4037b4bf0cdc2a..883430d385fafe1 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -128,14 +128,14 @@ func.func @dma_wait_wrong_index_type(%tag : memref<2x2xi32>, %idx: index, %flt:
// -----
func.func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
- // expected-error @+1 {{expected a permutation map}}
+ // expected-error @+1 {{expected a permutation array}}
memref.transpose %v [1, 1] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}
// -----
func.func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
- // expected-error @+1 {{expected a permutation map of same rank as the input}}
+ // expected-error @+1 {{expected a permutation array of same size as the input rank}}
memref.transpose %v [0] : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}
More information about the Mlir-commits
mailing list