[Mlir-commits] [mlir] f556af9 - [mlir] Fix materializations for unranked tensors.

Sean Silva llvmlistbot at llvm.org
Wed Nov 4 10:17:13 PST 2020


Author: Sean Silva
Date: 2020-11-04T10:16:55-08:00
New Revision: f556af965f11cfe614d722f59257ba116bee3f62

URL: https://github.com/llvm/llvm-project/commit/f556af965f11cfe614d722f59257ba116bee3f62
DIFF: https://github.com/llvm/llvm-project/commit/f556af965f11cfe614d722f59257ba116bee3f62.diff

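LOG: [mlir] Fix materializations for unranked tensors.

The BufferizeTypeConverter source and target materializations previously took RankedTensorType and MemRefType, so they did not apply to unranked tensors or memrefs. They now take TensorType and BaseMemRefType. New tests cover tensor_cast to and from tensor<*xf32>.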

Differential Revision: https://reviews.llvm.org/D90656

Added: 
    

Modified: 
    mlir/lib/Transforms/Bufferize.cpp
    mlir/test/Dialect/Standard/bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp
index 1564290cce4a..790b6d0ab9f2 100644
--- a/mlir/lib/Transforms/Bufferize.cpp
+++ b/mlir/lib/Transforms/Bufferize.cpp
@@ -27,13 +27,13 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
   addConversion([](UnrankedTensorType type) -> Type {
     return UnrankedMemRefType::get(type.getElementType(), 0);
   });
-  addSourceMaterialization([](OpBuilder &builder, RankedTensorType type,
+  addSourceMaterialization([](OpBuilder &builder, TensorType type,
                               ValueRange inputs, Location loc) -> Value {
     assert(inputs.size() == 1);
     assert(inputs[0].getType().isa<BaseMemRefType>());
     return builder.create<TensorLoadOp>(loc, type, inputs[0]);
   });
-  addTargetMaterialization([](OpBuilder &builder, MemRefType type,
+  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                               ValueRange inputs, Location loc) -> Value {
     assert(inputs.size() == 1);
     assert(inputs[0].getType().isa<TensorType>());

diff  --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
index b2cefe32120e..8cc05ff20644 100644
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -86,6 +86,28 @@ func @tensor_cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
   return %0 : tensor<2xindex>
 }
 
+// CHECK-LABEL:   func @tensor_cast_from_unranked(
+// CHECK-SAME:                                    %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<*xf32>
+// CHECK:           %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<*xf32> to memref<2xf32>
+// CHECK:           %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<2xf32>
+// CHECK:           return %[[RET]] : tensor<2xf32>
+func @tensor_cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
+  %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL:   func @tensor_cast_to_unranked(
+// CHECK-SAME:                                  %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> {
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<2xf32>
+// CHECK:           %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<2xf32> to memref<*xf32>
+// CHECK:           %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<*xf32>
+// CHECK:           return %[[RET]] : tensor<*xf32>
+func @tensor_cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
+  %0 = tensor_cast %arg0 : tensor<2xf32> to tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
 // CHECK-LABEL:   func @tensor_from_elements(
 // CHECK-SAME:                               %[[ELEM0:.*]]: index,
 // CHECK-SAME:                               %[[ELEM1:.*]]: index) -> tensor<2xindex> {


        


More information about the Mlir-commits mailing list