[Mlir-commits] [mlir] 84a6da6 - [mlir] Fix some edge cases around 0-element TensorFromElementsOp
Sean Silva
llvmlistbot at llvm.org
Fri Sep 11 10:58:51 PDT 2020
Author: Sean Silva
Date: 2020-09-11T10:58:35-07:00
New Revision: 84a6da67e6b2a76b15ad1862f4cbb7625fe318df
URL: https://github.com/llvm/llvm-project/commit/84a6da67e6b2a76b15ad1862f4cbb7625fe318df
DIFF: https://github.com/llvm/llvm-project/commit/84a6da67e6b2a76b15ad1862f4cbb7625fe318df.diff
LOG: [mlir] Fix some edge cases around 0-element TensorFromElementsOp
This introduces a builder for the more general case that supports zero
elements (where the element type can't be inferred from the ValueRange,
since it might be empty).
Also, fix up some cases in ShapeToStandard lowering that hit this. It
happens very easily when dealing with shapes of 0-D tensors.
The SameOperandsAndResultElementType trait is redundant with the new
TypesMatchWith constraint, and it prevented having zero elements (the trait's
verifier requires at least one operand).
Differential Revision: https://reviews.llvm.org/D87492
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
mlir/test/IR/core-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index ec7ecf9b92d4..afdc3edae86c 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1613,7 +1613,6 @@ def ExtractElementOp : Std_Op<"extract_element",
def TensorFromElementsOp : Std_Op<"tensor_from_elements", [
NoSideEffect,
- SameOperandsAndResultElementType,
TypesMatchWith<"operand types match result element type",
"result", "elements", "SmallVector<Type, 2>("
"$_self.cast<ShapedType>().getDimSize(0), "
@@ -1638,7 +1637,11 @@ def TensorFromElementsOp : Std_Op<"tensor_from_elements", [
// This op is fully verified by its traits.
let verifier = ?;
+ let skipDefaultBuilders = 1;
let builders = [
+ OpBuilder<"OpBuilder &b, OperationState &result, Type elementType,"
+ "ValueRange elements">,
+ // Special case builder for when `elements` has size >=1.
OpBuilder<"OpBuilder &b, OperationState &result, ValueRange elements">
];
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index f3f11e89af02..0a6953842a14 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -182,8 +182,9 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
extentOperands.push_back(
rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
}
- Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands);
Type indexTy = rewriter.getIndexType();
+ Value tensor =
+ rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
return success();
@@ -444,8 +445,8 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
}
// Materialize extent tensor.
- Value staticExtentTensor =
- rewriter.create<TensorFromElementsOp>(loc, extentValues);
+ Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
+ loc, rewriter.getIndexType(), extentValues);
rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
op.getType());
return success();
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index dc45d5175277..cf085a604b46 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1756,12 +1756,18 @@ OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
// TensorFromElementsOp
//===----------------------------------------------------------------------===//
+void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
+ Type elementType, ValueRange elements) {
+ Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
+ elementType);
+ result.addOperands(elements);
+ result.addTypes(resultTy);
+}
+
void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
ValueRange elements) {
assert(!elements.empty() && "expected at least one element");
- Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
- elements.front().getType());
- build(builder, result, resultTy, elements);
+ build(builder, result, elements.front().getType(), elements);
}
namespace {
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 4168634f1240..01ba6abcc6c4 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -103,6 +103,19 @@ func @const_shape() -> tensor<?xindex> {
// -----
+// Lower `const_shape` in the case of rank 0.
+// CHECK-LABEL: func @const_shape_zero_elements
+// CHECK-SAME: () -> tensor<?xindex>
+func @const_shape_zero_elements() -> tensor<?xindex> {
+ // CHECK: %[[TENSOR:.*]] = tensor_from_elements : tensor<0xindex>
+ // CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
+ // CHECK: return %[[RESULT]] : tensor<?xindex>
+ %shape = shape.const_shape [] : tensor<?xindex>
+ return %shape : tensor<?xindex>
+}
+
+// -----
+
// Lower `any` to its first operand.
// CHECK-LABEL: @any_of_three
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
@@ -227,6 +240,17 @@ func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
// -----
+// Lower `shape_of` for 0-D tensor.
+// CHECK-LABEL: @shape_of_zero_d
+// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
+func @shape_of_zero_d(%arg : tensor<f32>) {
+ // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements : tensor<0xindex>
+ %shape = shape.shape_of %arg : tensor<f32> -> tensor<?xindex>
+ return
+}
+
+// -----
+
// Lower `shape_of` for dynamically shaped tensor.
// CHECK-LABEL: @shape_of_dyn
// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index e4472b444f03..f182936c8703 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -673,6 +673,9 @@ func @tensor_from_elements() {
// CHECK: %2 = tensor_from_elements [[C0_F32]] : tensor<1xf32>
%2 = tensor_from_elements %c0_f32 : tensor<1xf32>
+ // CHECK: tensor_from_elements : tensor<0xindex>
+ %3 = tensor_from_elements : tensor<0xindex>
+
return
}
More information about the Mlir-commits mailing list