[Mlir-commits] [mlir] 9c94908 - Add support for the unranked case of `tensor_store` and `tensor_load` ops.

Alexander Belyaev llvmlistbot at llvm.org
Fri Aug 7 05:33:09 PDT 2020


Author: Alexander Belyaev
Date: 2020-08-07T14:32:52+02:00
New Revision: 9c94908320549a1a2328c758d6bbb694466021e7

URL: https://github.com/llvm/llvm-project/commit/9c94908320549a1a2328c758d6bbb694466021e7
DIFF: https://github.com/llvm/llvm-project/commit/9c94908320549a1a2328c758d6bbb694466021e7.diff

LOG: [mlir] Add support for the unranked case of `tensor_store` and `tensor_load` ops.

Differential Revision: https://reviews.llvm.org/D85518
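
For a concrete picture of what the patch enables, here is a minimal C++ sketch of driving the updated builders; it is illustrative only (the helper name loadAndStoreUnranked and its surrounding setup are not part of this change) and assumes an existing OpBuilder, Location, and a Value of type memref<*xi32>:

#include "mlir/Dialect/StandardOps/IR/Ops.h"

using namespace mlir;

// Round-trips a value through tensor_load/tensor_store. With the switch to
// AnyRankedOrUnrankedMemRef in Ops.td below, `memref` may now be unranked;
// the custom TensorLoadOp builder infers the matching tensor result type.
static Value loadAndStoreUnranked(OpBuilder &b, Location loc, Value memref) {
  // For memref<*xi32>, the inferred result type is tensor<*xi32>.
  Value tensor = b.create<TensorLoadOp>(loc, memref).getResult();
  // tensor_store writes the tensor back into the same unranked memref.
  b.create<TensorStoreOp>(loc, tensor, memref);
  return tensor;
}

At the IR level this corresponds to the new @unranked_tensor_load_store test added to core-ops.mlir below.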

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/IR/core-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index d9634fa2b9e6..088f262790d6 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2934,24 +2934,26 @@ def TensorLoadOp : Std_Op<"tensor_load",
     ```
   }];
 
-  let arguments = (ins Arg<AnyMemRef, "the reference to load from",
-                           [MemRead]>:$memref);
+  let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
+                       "the reference to load from", [MemRead]>:$memref);
   let results = (outs AnyTensor:$result);
   // TensorLoadOp is fully verified by traits.
   let verifier = ?;
 
   let builders = [OpBuilder<
     "OpBuilder &builder, OperationState &result, Value memref", [{
-      auto memrefType = memref.getType().cast<MemRefType>();
-      auto resultType = RankedTensorType::get(memrefType.getShape(),
-                                              memrefType.getElementType());
       result.addOperands(memref);
-      result.addTypes(resultType);
+      result.addTypes(getTensorTypeFromMemRefType(memref.getType()));
   }]>];
 
   let extraClassDeclaration = [{
     /// The result of a tensor_load is always a tensor.
-    TensorType getType() { return getResult().getType().cast<TensorType>(); }
+    TensorType getType() {
+      Type resultType = getResult().getType();
+      if (resultType.isa<TensorType>())
+        return resultType.cast<TensorType>();
+      return {};
+    }
   }];
 
   let assemblyFormat = "$memref attr-dict `:` type($memref)";
@@ -2981,9 +2983,8 @@ def TensorStoreOp : Std_Op<"tensor_store",
     ```
   }];
 
-  let arguments = (ins AnyTensor:$tensor,
-                       Arg<AnyMemRef, "the reference to store to",
-                           [MemWrite]>:$memref);
+  let arguments = (ins AnyTensor:$tensor, Arg<AnyRankedOrUnrankedMemRef,
+                       "the reference to store to", [MemWrite]>:$memref);
   // TensorStoreOp is fully verified by traits.
   let verifier = ?;
 

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index d084620f3a03..74e1e20ac1a9 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2985,6 +2985,17 @@ OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
 static Type getTensorTypeFromMemRefType(Type type) {
   if (auto memref = type.dyn_cast<MemRefType>())
     return RankedTensorType::get(memref.getShape(), memref.getElementType());
+  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
+    return UnrankedTensorType::get(memref.getElementType());
+  return NoneType::get(type.getContext());
+}
+
+static Type getMemRefTypeFromTensorType(Type type) {
+  if (auto tensor = type.dyn_cast<RankedTensorType>())
+    return MemRefType::get(tensor.getShape(), tensor.getElementType());
+  if (auto tensor = type.dyn_cast<UnrankedTensorType>())
+    return UnrankedMemRefType::get(tensor.getElementType(),
+                                   /*memorySpace=*/0);
   return NoneType::get(type.getContext());
 }
 

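The helpers above only translate shape and element type between the memref and tensor type hierarchies; a tensor carries no memory space, so the tensor-to-memref direction has to pick one (the default space 0 here). A rough, self-contained illustration of that mapping follows; the function illustrateTypeMapping is hypothetical, and the sketch assumes the StandardTypes-era C++ API of this revision:

#include "mlir/IR/StandardTypes.h"

using namespace mlir;

static void illustrateTypeMapping(MLIRContext *ctx) {
  Type i32 = IntegerType::get(32, ctx);
  // Ranked case: memref<4x4xi32> <-> tensor<4x4xi32>.
  MemRefType rankedMemRef = MemRefType::get({4, 4}, i32);
  Type rankedTensor = RankedTensorType::get(rankedMemRef.getShape(),
                                            rankedMemRef.getElementType());
  // Unranked case (newly supported): memref<*xi32> <-> tensor<*xi32>.
  // The tensor side carries no memory space, so the reverse direction
  // assumes the default memory space 0.
  UnrankedMemRefType unrankedMemRef =
      UnrankedMemRefType::get(i32, /*memorySpace=*/0);
  Type unrankedTensor =
      UnrankedTensorType::get(unrankedMemRef.getElementType());
  // The constructed types are unused; this sketch only shows the mapping.
  (void)rankedTensor;
  (void)unrankedTensor;
}
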
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 4c5fa8fb1eac..c45683c082e8 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -813,6 +813,15 @@ func @tensor_load_store(%0 : memref<4x4xi32>) {
   return
 }
 
+// CHECK-LABEL: func @unranked_tensor_load_store
+func @unranked_tensor_load_store(%0 : memref<*xi32>) {
+  // CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF:.*]] : memref<*xi32>
+  %1 = tensor_load %0 : memref<*xi32>
+  // CHECK: tensor_store %[[TENSOR]], %[[MEMREF]] : memref<*xi32>
+  tensor_store %1, %0 : memref<*xi32>
+  return
+}
+
 // CHECK-LABEL: func @atomic_rmw
 // CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index)
 func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {

More information about the Mlir-commits mailing list