[Mlir-commits] [mlir] 15f8f3e - [mlir] Split std.rank into tensor.rank and memref.rank.

Alexander Belyaev llvmlistbot at llvm.org
Tue Dec 14 01:16:21 PST 2021


Author: Alexander Belyaev
Date: 2021-12-14T10:15:55+01:00
New Revision: 15f8f3e20aa92349b0cb559d657f7648987edb06

URL: https://github.com/llvm/llvm-project/commit/15f8f3e20aa92349b0cb559d657f7648987edb06
DIFF: https://github.com/llvm/llvm-project/commit/15f8f3e20aa92349b0cb559d657f7648987edb06.diff

LOG: [mlir] Split std.rank into tensor.rank and memref.rank.

Move `std.rank` into the Tensor and MemRef dialects as `tensor.rank` and `memref.rank`, in the same way that `std.dim` was previously split between TensorOps and MemRefOps.

Differential Revision: https://reviews.llvm.org/D115665

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Transforms/BufferOptimizations.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/Dialect/MemRef/invalid.mlir
    mlir/test/Dialect/MemRef/ops.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir
    mlir/test/Dialect/Tensor/invalid.mlir
    mlir/test/Dialect/Tensor/ops.mlir
    mlir/test/IR/core-ops.mlir
    mlir/test/IR/invalid-ops.mlir
    mlir/test/Transforms/constant-fold.mlir
    mlir/test/Transforms/promote-buffers-to-stack.mlir

Removed: 
    mlir/test/Conversion/StandardToLLVM/rank.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c6b7e28fe0aa8..e529a50dae935 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -998,6 +998,31 @@ def MemRef_ReinterpretCastOp:
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// RankOp
+//===----------------------------------------------------------------------===//
+
+def MemRef_RankOp : MemRef_Op<"rank", [NoSideEffect]> {
+  let summary = "rank operation";
+  let description = [{
+    The `memref.rank` operation takes a memref operand and returns its rank.
+
+    Example:
+
+    ```mlir
+    %0 = memref.rank %arg0 : memref<*xf32>
+    %1 = memref.rank %arg1 : memref<?x?xf32>
+    ```
+  }];
+
+  let arguments = (ins AnyRankedOrUnrankedMemRef:$memref);
+  let results = (outs Index);
+
+  let verifier = ?;
+  let hasFolder = 1;
+  let assemblyFormat = "$memref attr-dict `:` type($memref)";
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 23b9df282af03..2e50971db9e7a 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -658,32 +658,6 @@ def ConstantOp : Std_Op<"constant",
   let hasFolder = 1;
 }
 
-//===----------------------------------------------------------------------===//
-// RankOp
-//===----------------------------------------------------------------------===//
-
-def RankOp : Std_Op<"rank", [NoSideEffect]> {
-  let summary = "rank operation";
-  let description = [{
-    The `rank` operation takes a memref/tensor operand and returns its rank.
-
-    Example:
-
-    ```mlir
-    %1 = rank %arg0 : tensor<*xf32>
-    %2 = rank %arg1 : memref<*xf32>
-    ```
-  }];
-
-  let arguments = (ins AnyTypeOf<[AnyRankedOrUnrankedMemRef, AnyTensor],
-                                 "any memref or tensor type">:$memrefOrTensor);
-  let results = (outs Index);
-  let verifier = ?;
-
-  let hasFolder = 1;
-  let assemblyFormat = "$memrefOrTensor attr-dict `:` type($memrefOrTensor)";
-}
-
 //===----------------------------------------------------------------------===//
 // ReturnOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3b1bfeeca6c10..21331fc649cd5 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -68,9 +68,9 @@ def Tensor_CastOp : Tensor_Op<"cast", [
 def Tensor_DimOp : Tensor_Op<"dim", [NoSideEffect]> {
   let summary = "dimension index operation";
   let description = [{
-    The `dim` operation takes a tensor and a dimension operand of type `index`.
-    It returns the size of the requested dimension of the given tensor.
-    If the dimension index is out of bounds, the behavior is undefined.
+    The `tensor.dim` operation takes a tensor and a dimension operand of type
+    `index`. It returns the size of the requested dimension of the given
+    tensor. If the dimension index is out of bounds, the behavior is undefined.
 
     The specified tensor type is that of the first operand.
 
@@ -558,6 +558,31 @@ def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides<
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// RankOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_RankOp : Tensor_Op<"rank", [NoSideEffect]> {
+  let summary = "rank operation";
+  let description = [{
+    The `tensor.rank` operation takes a tensor operand and returns its rank.
+
+    Example:
+
+    ```mlir
+    %0 = tensor.rank %arg0 : tensor<*xf32>
+    %1 = tensor.rank %arg1 : tensor<?x?xf32>
+    ```
+  }];
+
+  let arguments = (ins AnyTensor:$tensor);
+  let results = (outs Index);
+
+  let verifier = ?;
+  let hasFolder = 1;
+  let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 521b3fcab0c6f..28981dd87ecc9 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -596,6 +596,30 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
   }
 };
 
+/// Lowers `memref.rank`: for an unranked memref the rank is read from the
+/// unranked descriptor; for a ranked memref it is emitted as an index constant.
+struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
+  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type operandType = op.memref().getType();
+    if (operandType.isa<UnrankedMemRefType>()) {
+      UnrankedMemRefDescriptor desc(adaptor.memref());
+      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
+      return success();
+    }
+    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
+      rewriter.replaceOp(
+          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
+      return success();
+    }
+    return failure();
+  }
+};
+
 struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
   using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
 
@@ -1549,6 +1573,7 @@ void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
       MemRefReinterpretCastOpLowering,
       MemRefReshapeOpLowering,
       PrefetchOpLowering,
+      RankOpLowering,
       ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
       ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
       StoreOpLowering,

diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index e1e24faa4d2a6..5a1af7b33132e 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -577,7 +577,7 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
 
   // Lower to `tensor.generate` otherwise.
   auto *ctx = rewriter.getContext();
-  Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
+  Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
   rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
       op, getExtentTensorType(ctx), ValueRange{rank},
       [&](OpBuilder &b, Location loc, ValueRange args) {

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 200834a2d1bc6..f588521ac6ef0 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -566,28 +566,6 @@ struct UnrealizedConversionCastOpLowering
   }
 };
 
-struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
-  using ConvertOpToLLVMPattern<RankOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(RankOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Type operandType = op.getMemrefOrTensor().getType();
-    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
-      UnrankedMemRefDescriptor desc(adaptor.getMemrefOrTensor());
-      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
-      return success();
-    }
-    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
-      rewriter.replaceOp(
-          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
-      return success();
-    }
-    return failure();
-  }
-};
-
 // Common base for load and store operations on MemRefs.  Restricts the match
 // to supported MemRef types. Provides functionality to emit code accessing a
 // specific element of the underlying data buffer.
@@ -987,7 +965,6 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
       CondBranchOpLowering,
       ConstantOpLowering,
       GenericAtomicRMWOpLowering,
-      RankOpLowering,
       ReturnOpLowering,
       SelectOpLowering,
       SplatOpLowering,

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 4badc0b31ddb6..1916ffe36dd66 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1072,6 +1072,19 @@ LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
   return foldMemRefCast(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// RankOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+  // Constant fold rank when the rank of the operand is known.
+  auto type = getOperand().getType();
+  auto shapedType = type.dyn_cast<ShapedType>();
+  if (shapedType && shapedType.hasRank())
+    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
+  return IntegerAttr();
+}
+
 //===----------------------------------------------------------------------===//
 // ReinterpretCastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 1a43c0937d038..1d045b2912154 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -899,20 +899,6 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) {
   return value.isa<UnitAttr>();
 }
 
-//===----------------------------------------------------------------------===//
-// RankOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
-  // Constant fold rank when the rank of the operand is known.
-  auto type = getOperand().getType();
-  if (auto shapedType = type.dyn_cast<ShapedType>())
-    if (shapedType.hasRank())
-      return IntegerAttr::get(IndexType::get(getContext()),
-                              shapedType.getRank());
-  return IntegerAttr();
-}
-
 //===----------------------------------------------------------------------===//
 // ReturnOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index edddfb86e5539..ecdd966a3c35e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -609,6 +609,19 @@ void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
               StaticTensorGenerate>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// RankOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+  // Constant fold rank when the rank of the operand is known.
+  auto type = getOperand().getType();
+  auto shapedType = type.dyn_cast<ShapedType>();
+  if (shapedType && shapedType.hasRank())
+    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
+  return IntegerAttr();
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/BufferOptimizations.cpp b/mlir/lib/Transforms/BufferOptimizations.cpp
index 64a005dfb55b1..27e00a14c0d4d 100644
--- a/mlir/lib/Transforms/BufferOptimizations.cpp
+++ b/mlir/lib/Transforms/BufferOptimizations.cpp
@@ -37,14 +37,16 @@ static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes,
   if (!type || !alloc.getDefiningOp<memref::AllocOp>())
     return false;
   if (!type.hasStaticShape()) {
-    // Check if the dynamic shape dimension of the alloc is produced by RankOp.
-    // If this is the case, it is likely to be small. Furthermore, the dimension
-    // is limited to the maximum rank of the allocated memref to avoid large
-    // values by multiplying several small values.
+    // Check if the dynamic shape dimension of the alloc is produced by
+    // `memref.rank`. If this is the case, it is likely to be small.
+    // Furthermore, the dimension is limited to the maximum rank of the
+    // allocated memref to avoid large values by multiplying several small
+    // values.
     if (type.getRank() <= maxRankOfAllocatedMemRef) {
-      return llvm::all_of(
-          alloc.getDefiningOp()->getOperands(),
-          [&](Value operand) { return operand.getDefiningOp<RankOp>(); });
+      return llvm::all_of(alloc.getDefiningOp()->getOperands(),
+                          [&](Value operand) {
+                            return operand.getDefiningOp<memref::RankOp>();
+                          });
     }
     return false;
   }

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index a26638a34151f..009106f95e8a5 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -1,7 +1,6 @@
 // RUN: mlir-opt -convert-memref-to-llvm %s -split-input-file | FileCheck %s
 // RUN: mlir-opt -convert-memref-to-llvm='index-bitwidth=32' %s -split-input-file | FileCheck --check-prefix=CHECK32 %s
 
-
 // CHECK-LABEL: func @view(
 // CHECK: %[[ARG0F:.*]]: index, %[[ARG1F:.*]]: index, %[[ARG2F:.*]]: index
 func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
@@ -835,3 +834,24 @@ func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> {
 // CHECK:           llvm.mul %{{.*}}, %{{.*}}  : i64
 // CHECK:           llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           llvm.mul %{{.*}}, %{{.*}}  : i64
+
+// -----
+
+// CHECK-LABEL: func @rank_of_unranked
+// CHECK32-LABEL: func @rank_of_unranked
+func @rank_of_unranked(%unranked: memref<*xi32>) {
+  %rank = memref.rank %unranked : memref<*xi32>
+  return
+}
+// CHECK: %[[UNRANKED_DESC:.*]] = builtin.unrealized_conversion_cast
+// CHECK-NEXT: llvm.extractvalue %[[UNRANKED_DESC]][0] : !llvm.struct<(i64, ptr<i8>)>
+// CHECK32: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(i32, ptr<i8>)>
+
+// CHECK-LABEL: func @rank_of_ranked
+// CHECK32-LABEL: func @rank_of_ranked
+func @rank_of_ranked(%ranked: memref<?xi32>) {
+  %rank = memref.rank %ranked : memref<?xi32>
+  return
+}
+// CHECK: llvm.mlir.constant(1 : index) : i64
+// CHECK32: llvm.mlir.constant(1 : index) : i32

diff  --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 015cb2fcaaf43..ea0ef33862ce5 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -203,7 +203,7 @@ func @shape_of(%arg : tensor<*xf32>) {
 // CHECK-LABEL: @shape_of_unranked
 // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
 func @shape_of_unranked(%arg : tensor<*xf32>) {
-  // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
+  // CHECK: %[[RANK:.*]] = tensor.rank %[[ARG]] : tensor<*xf32>
   // CHECK: %[[SHAPE:.*]] = tensor.generate %[[RANK]] {
   // CHECK: ^bb0(%[[I:.*]]: index):
   // CHECK:   %[[EXTENT:.*]] = tensor.dim %[[ARG]], %[[I]] : tensor<*xf32>

diff  --git a/mlir/test/Conversion/StandardToLLVM/rank.mlir b/mlir/test/Conversion/StandardToLLVM/rank.mlir
deleted file mode 100644
index 7c0a03aa8df37..0000000000000
--- a/mlir/test/Conversion/StandardToLLVM/rank.mlir
+++ /dev/null
@@ -1,23 +0,0 @@
-// RUN: mlir-opt -convert-std-to-llvm %s -split-input-file | FileCheck %s
-// RUN: mlir-opt -convert-std-to-llvm='index-bitwidth=32' %s -split-input-file | FileCheck --check-prefix=CHECK32 %s
-
-// CHECK-LABEL: func @rank_of_unranked
-// CHECK32-LABEL: func @rank_of_unranked
-func @rank_of_unranked(%unranked: memref<*xi32>) {
-  %rank = rank %unranked : memref<*xi32>
-  return
-}
-// CHECK-NEXT: llvm.mlir.undef
-// CHECK-NEXT: llvm.insertvalue
-// CHECK-NEXT: llvm.insertvalue
-// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(i64, ptr<i8>)>
-// CHECK32: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(i32, ptr<i8>)>
-
-// CHECK-LABEL: func @rank_of_ranked
-// CHECK32-LABEL: func @rank_of_ranked
-func @rank_of_ranked(%ranked: memref<?xi32>) {
-  %rank = rank %ranked : memref<?xi32>
-  return
-}
-// CHECK: llvm.mlir.constant(1 : index) : i64
-// CHECK32: llvm.mlir.constant(1 : index) : i32

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 251658fac7653..80282c21afab0 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -185,10 +185,10 @@ func @dim_of_alloca(%size: index) -> index {
 // Test case: Folding of memref.dim(memref.alloca(rank(%v)), %idx) -> rank(%v)
 // CHECK-LABEL: func @dim_of_alloca_with_dynamic_size(
 //  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>
-//  CHECK-NEXT:   %[[RANK:.*]] = rank %[[MEM]] : memref<*xf32>
+//  CHECK-NEXT:   %[[RANK:.*]] = memref.rank %[[MEM]] : memref<*xf32>
 //  CHECK-NEXT:   return %[[RANK]] : index
 func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
-  %0 = rank %arg0 : memref<*xf32>
+  %0 = memref.rank %arg0 : memref<*xf32>
   %1 = memref.alloca(%0) : memref<?xindex>
   %c0 = arith.constant 0 : index
   %2 = memref.dim %1, %c0 : memref<?xindex>
@@ -438,3 +438,15 @@ func @reduced_memref(%arg0: memref<2x5x7x1xf32>, %arg1 :index)
 //       CHECK:   %[[RESULT:.+]] = memref.subview
 //  CHECK-SAME:       memref<2x5x7x1xf32> to memref<1x4x1xf32, #{{.+}}>
 //       CHECK:   return %[[RESULT]]
+
+// -----
+
+// CHECK-LABEL: func @fold_rank_memref
+func @fold_rank_memref(%arg0 : memref<?x?xf32>) -> (index) {
+  // Fold a rank into a constant
+  // CHECK-NEXT: [[C2:%.+]] = arith.constant 2 : index
+  %rank_0 = memref.rank %arg0 : memref<?x?xf32>
+
+  // CHECK-NEXT: return [[C2]]
+  return %rank_0 : index
+}

diff  --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 6014687e6e9dd..55c5a821fb3dd 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -844,3 +844,11 @@ func @test_alloc_memref_map_rank_mismatch() {
   %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0) -> (d0)>, 1>
   return
 }
+
+// -----
+
+func @rank(%0: f32) {
+  // expected-error@+1 {{'memref.rank' op operand #0 must be unranked.memref of any type values or memref of any type values}}
+  "memref.rank"(%0): (f32)->index
+  return
+}

diff  --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index f716c5de21742..4ff2f8b5517be 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -207,3 +207,14 @@ func @collapse_shape_to_dynamic
 //      CHECK: func @collapse_shape_to_dynamic
 //      CHECK:   memref.collapse_shape
 // CHECK-SAME:    [0], [1], [2, 3, 4]
+
+// -----
+
+func @rank(%t : memref<4x4x?xf32>) {
+  // CHECK: %{{.*}} = memref.rank %{{.*}} : memref<4x4x?xf32>
+  %0 = "memref.rank"(%t) : (memref<4x4x?xf32>) -> index
+
+  // CHECK: %{{.*}} = memref.rank %{{.*}} : memref<4x4x?xf32>
+  %1 = memref.rank %t : memref<4x4x?xf32>
+  return
+}

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index fc9abe439b8a2..ec9601e269939 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -183,7 +183,7 @@ func @extract_oob_from_tensor.from_elements(%element : index) -> index {
 // CHECK-LABEL: func @extract_from_tensor.generate
 // CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
 func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index {
-  %size = rank %tensor : tensor<*xf32>
+  %size = tensor.rank %tensor : tensor<*xf32>
   // CHECK-NEXT: %[[RES:.*]] = tensor.dim %[[TENSOR]], %[[IDX]]
   %0 = tensor.generate %size {
     ^bb0(%arg0: index):
@@ -200,7 +200,7 @@ func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index
 // CHECK-LABEL: func @extract_from_tensor.generate_2d
 // CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
 func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index {
-  %size = rank %tensor : tensor<*xf32>
+  %size = tensor.rank %tensor : tensor<*xf32>
   // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[TENSOR]], %[[IDX0]]
   // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[TENSOR]], %[[IDX1]]
   // CHECK-NEXT: %[[RES:.*]] = arith.addi %[[DIM0]], %[[DIM1]]
@@ -221,7 +221,7 @@ func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tenso
 // CHECK-LABEL: func @extract_from_tensor.generate_sideeffects
 // CHECK-SAME: %[[IDX:.*]]: index
 func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>, %mem: memref<?xindex>) -> index {
-  %size = rank %tensor : tensor<*xf32>
+  %size = tensor.rank %tensor : tensor<*xf32>
   // CHECK: %[[DTENSOR:.*]] = tensor.generate
   %0 = tensor.generate %size {
     ^bb0(%arg0: index):
@@ -900,3 +900,18 @@ func @reshape_splat_constant_float64() -> tensor<2x4x2xf64> {
 //       CHECK:   %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xf64>
 //   CHECK-NOT:   tensor.expand_shape
 //       CHECK:   return %[[CST]]
+
+// -----
+
+// CHECK-LABEL: func @fold_rank
+func @fold_rank() -> (index) {
+  %const_0 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]>
+    : tensor<2x1x4xi32>
+
+  // Fold the rank of a statically ranked tensor into a constant.
+  // CHECK-NEXT: [[C3:%.+]] = arith.constant 3 : index
+  %rank_0 = tensor.rank %const_0 : tensor<2x1x4xi32>
+
+  // CHECK-NEXT: return [[C3]]
+  return %rank_0 : index
+}

diff  --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 8b40ec80e02d4..564526f16370f 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -292,3 +292,11 @@ func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>)
       : tensor<?x4x5xf32> into tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
+// -----
+
+func @rank(%0: f32) {
+  // expected-error@+1 {{'tensor.rank' op operand #0 must be tensor of any type values}}
+  "tensor.rank"(%0): (f32)->index
+  return
+}

diff  --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 63afc1f382b37..8d50d15184218 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -160,3 +160,14 @@ func @legal_collapsing_reshape_dynamic_tensor
 //      CHECK: func @legal_collapsing_reshape_dynamic_tensor
 //      CHECK:   tensor.collapse_shape
 // CHECK-SAME:    [0], [1], [2, 3, 4]
+
+// -----
+
+func @rank(%t : tensor<4x4x?xf32>) {
+  // CHECK: %{{.*}} = tensor.rank %{{.*}} : tensor<4x4x?xf32>
+  %0 = "tensor.rank"(%t) : (tensor<4x4x?xf32>) -> index
+
+  // CHECK: %{{.*}} = tensor.rank %{{.*}} : tensor<4x4x?xf32>
+  %1 = tensor.rank %t : tensor<4x4x?xf32>
+  return
+}

diff  --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index fe2d7207d3d01..b83f530eeacc6 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -99,12 +99,6 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
   // CHECK: %{{.*}} = arith.cmpf oeq, %{{.*}}, %{{.*}}: vector<4xf32>
   %70 = arith.cmpf oeq, %vcf32, %vcf32 : vector<4 x f32>
 
-  // CHECK: %{{.*}} = rank %arg0 : tensor<4x4x?xf32>
-  %71 = "std.rank"(%t) : (tensor<4x4x?xf32>) -> index
-
-  // CHECK: %{{.*}} = rank %arg0 : tensor<4x4x?xf32>
-  %72 = rank %t : tensor<4x4x?xf32>
-
   // CHECK: = constant unit
   %73 = constant unit
 

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 13cfd16daf9a4..49f29f09bf492 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1,13 +1,5 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics
 
-func @rank(f32) {
-^bb(%0: f32):
-  "std.rank"(%0): (f32)->index // expected-error {{'std.rank' op operand #0 must be any memref or tensor type}}
-
-  return
-}
-
-// -----
 func @affine_apply_no_map() {
 ^bb0:
   %i = arith.constant 0 : index

diff  --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index 5406a8588ce4b..2e720eae3439c 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -754,32 +754,6 @@ func @cmpf_inf() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1,
 
 // -----
 
-// CHECK-LABEL: func @fold_rank
-func @fold_rank() -> (index) {
-  %const_0 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
-
-  // Fold a rank into a constant
-  // CHECK-NEXT: [[C3:%.+]] = arith.constant 3 : index
-  %rank_0 = rank %const_0 : tensor<2x1x4xi32>
-
-  // CHECK-NEXT: return [[C3]]
-  return %rank_0 : index
-}
-
-// -----
-
-// CHECK-LABEL: func @fold_rank_memref
-func @fold_rank_memref(%arg0 : memref<?x?xf32>) -> (index) {
-  // Fold a rank into a constant
-  // CHECK-NEXT: [[C2:%.+]] = arith.constant 2 : index
-  %rank_0 = rank %arg0 : memref<?x?xf32>
-
-  // CHECK-NEXT: return [[C2]]
-  return %rank_0 : index
-}
-
-// -----
-
 // CHECK-LABEL: func @nested_isolated_region
 func @nested_isolated_region() {
   // CHECK-NEXT: func @isolated_op

diff  --git a/mlir/test/Transforms/promote-buffers-to-stack.mlir b/mlir/test/Transforms/promote-buffers-to-stack.mlir
index c78f8a71dbb7b..2b6cd3185fa11 100644
--- a/mlir/test/Transforms/promote-buffers-to-stack.mlir
+++ b/mlir/test/Transforms/promote-buffers-to-stack.mlir
@@ -77,25 +77,25 @@ func @condBranchDynamicType(
 // -----
 
 // CHECK-LABEL: func @dynamicRanked
-func @dynamicRanked(%tensor: tensor<*xf32>) {
-  %0 = rank %tensor : tensor<*xf32>
+func @dynamicRanked(%memref: memref<*xf32>) {
+  %0 = memref.rank %memref : memref<*xf32>
   %1 = memref.alloc(%0) : memref<?xindex>
   return
 }
 
-// CHECK-NEXT: %[[RANK:.*]] = rank
+// CHECK-NEXT: %[[RANK:.*]] = memref.rank %{{.*}} : memref<*xf32>
 // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca(%[[RANK]])
 
 // -----
 
 // CHECK-LABEL: func @dynamicRanked2D
-func @dynamicRanked2D(%tensor: tensor<*xf32>) {
-  %0 = rank %tensor : tensor<*xf32>
+func @dynamicRanked2D(%memref: memref<*xf32>) {
+  %0 = memref.rank %memref : memref<*xf32>
   %1 = memref.alloc(%0, %0) : memref<?x?xindex>
   return
 }
 
-// CHECK-NEXT: %[[RANK:.*]] = rank
+// CHECK-NEXT: %[[RANK:.*]] = memref.rank %{{.*}} : memref<*xf32>
 //  RANK-NEXT: %[[ALLOC:.*]] = memref.alloca(%[[RANK]], %[[RANK]])
 // DEFINDEX-NEXT: %[[ALLOC:.*]] = memref.alloc(%[[RANK]], %[[RANK]])
 


        


More information about the Mlir-commits mailing list