[Mlir-commits] [mlir] 823ffef - [mlir][Standard] Allow unranked memrefs as operands to dim and rank

Stephan Herhut llvmlistbot at llvm.org
Wed Jul 29 05:43:39 PDT 2020


Author: Stephan Herhut
Date: 2020-07-29T14:42:58+02:00
New Revision: 823ffef009152ba1210740c44d472d1d6e56afa3

URL: https://github.com/llvm/llvm-project/commit/823ffef009152ba1210740c44d472d1d6e56afa3
DIFF: https://github.com/llvm/llvm-project/commit/823ffef009152ba1210740c44d472d1d6e56afa3.diff

LOG: [mlir][Standard] Allow unranked memrefs as operands to dim and rank

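Previously, `std.dim` only accepted ranked memrefs (and tensors), and
`std.rank` was limited to tensors. This change lets both operations also take
unranked memrefs, extends `std.rank` to ranked memrefs, and teaches the
`std.rank` folder to constant-fold any operand whose rank is statically known.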

Differential Revision: https://reviews.llvm.org/D84790

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/IR/invalid-ops.mlir
    mlir/test/Transforms/constant-fold.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 78307b8974763..d9634fa2b9e69 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1409,7 +1409,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
     ```
   }];
 
-  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor],
+  let arguments = (ins AnyTypeOf<[AnyRankedOrUnrankedMemRef, AnyTensor],
                                  "any tensor or memref type">:$memrefOrTensor,
                        Index:$index);
   let results = (outs Index:$result);
@@ -2024,16 +2024,18 @@ def PrefetchOp : Std_Op<"prefetch"> {
 def RankOp : Std_Op<"rank", [NoSideEffect]> {
   let summary = "rank operation";
   let description = [{
-    The `rank` operation takes a tensor operand and returns its rank.
+    The `rank` operation takes a memref/tensor operand and returns its rank.
 
     Example:
 
     ```mlir
-    %1 = rank %0 : tensor<*xf32>
+    %1 = rank %arg0 : tensor<*xf32>
+    %2 = rank %arg1 : memref<*xf32>
     ```
   }];
 
-  let arguments = (ins AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyRankedOrUnrankedMemRef, AnyTensor],
+                                 "any tensor or memref type">:$memrefOrTensor);
   let results = (outs Index);
   let verifier = ?;
 
@@ -2044,7 +2046,7 @@ def RankOp : Std_Op<"rank", [NoSideEffect]> {
     }]>];
 
   let hasFolder = 1;
-  let assemblyFormat = "operands attr-dict `:` type(operands)";
+  let assemblyFormat = "$memrefOrTensor attr-dict `:` type($memrefOrTensor)";
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 84c35c9fb7a56..a67e79ac4a7ce 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2039,10 +2039,12 @@ LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
 //===----------------------------------------------------------------------===//
 
 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
-  // Constant fold rank when the rank of the tensor is known.
+  // Constant fold rank when the rank of the operand is known.
   auto type = getOperand().getType();
-  if (auto tensorType = type.dyn_cast<RankedTensorType>())
-    return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank());
+  if (auto shapedType = type.dyn_cast<ShapedType>())
+    if (shapedType.hasRank())
+      return IntegerAttr::get(IndexType::get(getContext()),
+                              shapedType.getRank());
   return IntegerAttr();
 }
 

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 3668c25253adc..6302a8a4acbf9 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -10,7 +10,7 @@ func @dim(%arg : tensor<1x?xf32>) {
 
 func @rank(f32) {
 ^bb(%0: f32):
-  "std.rank"(%0): (f32)->index // expected-error {{'std.rank' op operand #0 must be tensor of any type values}}
+  "std.rank"(%0): (f32)->index // expected-error {{'std.rank' op operand #0 must be any tensor or memref type}}
   return
 }
 

diff  --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index 0677b95723701..36fa234213ea1 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -686,6 +686,18 @@ func @fold_rank() -> (index) {
 
 // -----
 
+// CHECK-LABEL: func @fold_rank_memref
+func @fold_rank_memref(%arg0 : memref<?x?xf32>) -> (index) {
+  // Fold a rank into a constant
+  // CHECK-NEXT: [[C2:%.+]] = constant 2 : index
+  %rank_0 = rank %arg0 : memref<?x?xf32>
+
+  // CHECK-NEXT: return [[C2]]
+  return %rank_0 : index
+}
+
+// -----
+
 // CHECK-LABEL: func @nested_isolated_region
 func @nested_isolated_region() {
   // CHECK-NEXT: func @isolated_op


        


More information about the Mlir-commits mailing list