[Mlir-commits] [mlir] f556af9 - [mlir] Fix materializations for unranked tensors.
Sean Silva
llvmlistbot at llvm.org
Wed Nov 4 10:17:13 PST 2020
Author: Sean Silva
Date: 2020-11-04T10:16:55-08:00
New Revision: f556af965f11cfe614d722f59257ba116bee3f62
URL: https://github.com/llvm/llvm-project/commit/f556af965f11cfe614d722f59257ba116bee3f62
DIFF: https://github.com/llvm/llvm-project/commit/f556af965f11cfe614d722f59257ba116bee3f62.diff
LOG: [mlir] Fix materializations for unranked tensors.
Differential Revision: https://reviews.llvm.org/D90656
Added:
Modified:
mlir/lib/Transforms/Bufferize.cpp
mlir/test/Dialect/Standard/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp
index 1564290cce4a..790b6d0ab9f2 100644
--- a/mlir/lib/Transforms/Bufferize.cpp
+++ b/mlir/lib/Transforms/Bufferize.cpp
@@ -27,13 +27,13 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
addConversion([](UnrankedTensorType type) -> Type {
return UnrankedMemRefType::get(type.getElementType(), 0);
});
- addSourceMaterialization([](OpBuilder &builder, RankedTensorType type,
+ addSourceMaterialization([](OpBuilder &builder, TensorType type,
ValueRange inputs, Location loc) -> Value {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<BaseMemRefType>());
return builder.create<TensorLoadOp>(loc, type, inputs[0]);
});
- addTargetMaterialization([](OpBuilder &builder, MemRefType type,
+ addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
ValueRange inputs, Location loc) -> Value {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<TensorType>());
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
index b2cefe32120e..8cc05ff20644 100644
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -86,6 +86,28 @@ func @tensor_cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
return %0 : tensor<2xindex>
}
+// CHECK-LABEL: func @tensor_cast_from_unranked(
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
+// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<*xf32>
+// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<*xf32> to memref<2xf32>
+// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<2xf32>
+// CHECK: return %[[RET]] : tensor<2xf32>
+func @tensor_cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
+ %0 = tensor_cast %arg0 : tensor<*xf32> to tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @tensor_cast_to_unranked(
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<2xf32>) -> tensor<*xf32> {
+// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<2xf32>
+// CHECK: %[[CASTED_MEMREF:.*]] = memref_cast %[[MEMREF]] : memref<2xf32> to memref<*xf32>
+// CHECK: %[[RET:.*]] = tensor_load %[[CASTED_MEMREF]] : memref<*xf32>
+// CHECK: return %[[RET]] : tensor<*xf32>
+func @tensor_cast_to_unranked(%arg0: tensor<2xf32>) -> tensor<*xf32> {
+ %0 = tensor_cast %arg0 : tensor<2xf32> to tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
// CHECK-LABEL: func @tensor_from_elements(
// CHECK-SAME: %[[ELEM0:.*]]: index,
// CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> {
More information about the Mlir-commits mailing list