[Mlir-commits] [mlir] 823ffef - [mlir][Standard] Allow unranked memrefs as operands to dim and rank
Stephan Herhut
llvmlistbot at llvm.org
Wed Jul 29 05:43:39 PDT 2020
Author: Stephan Herhut
Date: 2020-07-29T14:42:58+02:00
New Revision: 823ffef009152ba1210740c44d472d1d6e56afa3
URL: https://github.com/llvm/llvm-project/commit/823ffef009152ba1210740c44d472d1d6e56afa3
DIFF: https://github.com/llvm/llvm-project/commit/823ffef009152ba1210740c44d472d1d6e56afa3.diff
LOG: [mlir][Standard] Allow unranked memrefs as operands to dim and rank
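Previously, `std.dim` only accepted ranked memrefs and `std.rank` was limited to
tensors. This change lets both ops accept unranked memrefs, and `std.rank` now
also accepts ranked memrefs.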
Differential Revision: https://reviews.llvm.org/D84790
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/IR/invalid-ops.mlir
mlir/test/Transforms/constant-fold.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 78307b8974763..d9634fa2b9e69 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1409,7 +1409,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
```
}];
- let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor],
+ let arguments = (ins AnyTypeOf<[AnyRankedOrUnrankedMemRef, AnyTensor],
"any tensor or memref type">:$memrefOrTensor,
Index:$index);
let results = (outs Index:$result);
@@ -2024,16 +2024,18 @@ def PrefetchOp : Std_Op<"prefetch"> {
def RankOp : Std_Op<"rank", [NoSideEffect]> {
let summary = "rank operation";
let description = [{
- The `rank` operation takes a tensor operand and returns its rank.
+ The `rank` operation takes a memref/tensor operand and returns its rank.
Example:
```mlir
- %1 = rank %0 : tensor<*xf32>
+ %1 = rank %arg0 : tensor<*xf32>
+ %2 = rank %arg1 : memref<*xf32>
```
}];
- let arguments = (ins AnyTensor);
+ let arguments = (ins AnyTypeOf<[AnyRankedOrUnrankedMemRef, AnyTensor],
+ "any tensor or memref type">:$memrefOrTensor);
let results = (outs Index);
let verifier = ?;
@@ -2044,7 +2046,7 @@ def RankOp : Std_Op<"rank", [NoSideEffect]> {
}]>];
let hasFolder = 1;
- let assemblyFormat = "operands attr-dict `:` type(operands)";
+ let assemblyFormat = "$memrefOrTensor attr-dict `:` type($memrefOrTensor)";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 84c35c9fb7a56..a67e79ac4a7ce 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2039,10 +2039,12 @@ LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
//===----------------------------------------------------------------------===//
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
- // Constant fold rank when the rank of the tensor is known.
+ // Constant fold rank when the rank of the operand is known.
auto type = getOperand().getType();
- if (auto tensorType = type.dyn_cast<RankedTensorType>())
- return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank());
+ if (auto shapedType = type.dyn_cast<ShapedType>())
+ if (shapedType.hasRank())
+ return IntegerAttr::get(IndexType::get(getContext()),
+ shapedType.getRank());
return IntegerAttr();
}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 3668c25253adc..6302a8a4acbf9 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -10,7 +10,7 @@ func @dim(%arg : tensor<1x?xf32>) {
func @rank(f32) {
^bb(%0: f32):
- "std.rank"(%0): (f32)->index // expected-error {{'std.rank' op operand #0 must be tensor of any type values}}
+ "std.rank"(%0): (f32)->index // expected-error {{'std.rank' op operand #0 must be any tensor or memref type}}
return
}
diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index 0677b95723701..36fa234213ea1 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -686,6 +686,18 @@ func @fold_rank() -> (index) {
// -----
+// CHECK-LABEL: func @fold_rank_memref
+func @fold_rank_memref(%arg0 : memref<?x?xf32>) -> (index) {
+ // Fold a rank into a constant
+ // CHECK-NEXT: [[C2:%.+]] = constant 2 : index
+ %rank_0 = rank %arg0 : memref<?x?xf32>
+
+ // CHECK-NEXT: return [[C2]]
+ return %rank_0 : index
+}
+
+// -----
+
// CHECK-LABEL: func @nested_isolated_region
func @nested_isolated_region() {
// CHECK-NEXT: func @isolated_op
More information about the Mlir-commits mailing list