[Mlir-commits] [mlir] 84a6da6 - [mlir] Fix some edge cases around 0-element TensorFromElementsOp

Sean Silva llvmlistbot at llvm.org
Fri Sep 11 10:58:51 PDT 2020


Author: Sean Silva
Date: 2020-09-11T10:58:35-07:00
New Revision: 84a6da67e6b2a76b15ad1862f4cbb7625fe318df

URL: https://github.com/llvm/llvm-project/commit/84a6da67e6b2a76b15ad1862f4cbb7625fe318df
DIFF: https://github.com/llvm/llvm-project/commit/84a6da67e6b2a76b15ad1862f4cbb7625fe318df.diff

LOG: [mlir] Fix some edge cases around 0-element TensorFromElementsOp

This introduces a builder for the more general case that supports zero
elements (where the element type can't be inferred from the ValueRange,
since it might be empty).

Also, fix up some cases in ShapeToStandard lowering that hit this. It
happens very easily when dealing with shapes of 0-D tensors.

The SameOperandsAndResultElementType is redundant with the new
TypesMatchWith and prevented having zero elements.

Differential Revision: https://reviews.llvm.org/D87492

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
    mlir/test/IR/core-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index ec7ecf9b92d4..afdc3edae86c 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1613,7 +1613,6 @@ def ExtractElementOp : Std_Op<"extract_element",
 
 def TensorFromElementsOp : Std_Op<"tensor_from_elements", [
     NoSideEffect,
-    SameOperandsAndResultElementType,
     TypesMatchWith<"operand types match result element type",
                    "result", "elements", "SmallVector<Type, 2>("
                    "$_self.cast<ShapedType>().getDimSize(0), "
@@ -1638,7 +1637,11 @@ def TensorFromElementsOp : Std_Op<"tensor_from_elements", [
   // This op is fully verified by its traits.
   let verifier = ?;
 
+  let skipDefaultBuilders = 1;
   let builders = [
+    OpBuilder<"OpBuilder &b, OperationState &result, Type elementType,"
+    "ValueRange elements">,
+    // Special case builder for when `elements` has size >=1.
     OpBuilder<"OpBuilder &b, OperationState &result, ValueRange elements">
   ];
 

diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index f3f11e89af02..0a6953842a14 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -182,8 +182,9 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
     extentOperands.push_back(
         rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
   }
-  Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands);
   Type indexTy = rewriter.getIndexType();
+  Value tensor =
+      rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
   Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
   rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
   return success();
@@ -444,8 +445,8 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
     }
 
     // Materialize extent tensor.
-    Value staticExtentTensor =
-        rewriter.create<TensorFromElementsOp>(loc, extentValues);
+    Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
+        loc, rewriter.getIndexType(), extentValues);
     rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
                                               op.getType());
     return success();

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index dc45d5175277..cf085a604b46 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1756,12 +1756,18 @@ OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
 // TensorFromElementsOp
 //===----------------------------------------------------------------------===//
 
+void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
+                                 Type elementType, ValueRange elements) {
+  Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
+                                        elementType);
+  result.addOperands(elements);
+  result.addTypes(resultTy);
+}
+
 void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
                                  ValueRange elements) {
   assert(!elements.empty() && "expected at least one element");
-  Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
-                                        elements.front().getType());
-  build(builder, result, resultTy, elements);
+  build(builder, result, elements.front().getType(), elements);
 }
 
 namespace {

diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 4168634f1240..01ba6abcc6c4 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -103,6 +103,19 @@ func @const_shape() -> tensor<?xindex> {
 
 // -----
 
+// Lower `const_shape` in the case of rank 0.
+// CHECK-LABEL: func @const_shape_zero_elements
+// CHECK-SAME: () -> tensor<?xindex>
+func @const_shape_zero_elements() -> tensor<?xindex> {
+  // CHECK: %[[TENSOR:.*]] = tensor_from_elements : tensor<0xindex>
+  // CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
+  // CHECK: return %[[RESULT]] : tensor<?xindex>
+  %shape = shape.const_shape [] : tensor<?xindex>
+  return %shape : tensor<?xindex>
+}
+
+// -----
+
 // Lower `any` to its first operand.
 // CHECK-LABEL: @any_of_three
 // CHECK-SAME:  (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
@@ -227,6 +240,17 @@ func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
 
 // -----
 
+// Lower `shape_of` for 0-D tensor.
+// CHECK-LABEL: @shape_of_zero_d
+// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
+func @shape_of_zero_d(%arg : tensor<f32>) {
+  // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements : tensor<0xindex>
+  %shape = shape.shape_of %arg : tensor<f32> -> tensor<?xindex>
+  return
+}
+
+// -----
+
 // Lower `shape_of` for dynamically shaped tensor.
 // CHECK-LABEL: @shape_of_dyn
 // CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)

diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index e4472b444f03..f182936c8703 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -673,6 +673,9 @@ func @tensor_from_elements() {
   // CHECK: %2 = tensor_from_elements [[C0_F32]] : tensor<1xf32>
   %2 = tensor_from_elements %c0_f32 : tensor<1xf32>
 
+  // CHECK: tensor_from_elements : tensor<0xindex>
+  %3 = tensor_from_elements : tensor<0xindex>
+
   return
 }
 


        


More information about the Mlir-commits mailing list