[Mlir-commits] [mlir] 5984d74 - [MLIR][Shape] Allow `get_extent` to operate on extent tensors and indices

Frederik Gossen llvmlistbot at llvm.org
Fri Jul 24 04:13:36 PDT 2020


Author: Frederik Gossen
Date: 2020-07-24T11:13:17Z
New Revision: 5984d74139d45cec5cc8c55b107b9cb5d801c03e

URL: https://github.com/llvm/llvm-project/commit/5984d74139d45cec5cc8c55b107b9cb5d801c03e
DIFF: https://github.com/llvm/llvm-project/commit/5984d74139d45cec5cc8c55b107b9cb5d801c03e.diff

LOG: [MLIR][Shape] Allow `get_extent` to operate on extent tensors and indices

Differential Revision: https://reviews.llvm.org/D84435

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
    mlir/test/Dialect/Shape/canonicalize.mlir
    mlir/test/Dialect/Shape/invalid.mlir
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 64dba487c507..32d6ebafff32 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -235,9 +235,10 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
     an error then it returns an error size.
   }];
   let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
-                       Shape_SizeType:$dim);
-  let results = (outs Shape_SizeType:$extent);
-  let assemblyFormat = "$shape `,` $dim `:` type($shape)  attr-dict";
+                       Shape_SizeOrIndexType:$dim);
+  let results = (outs Shape_SizeOrIndexType:$extent);
+  let assemblyFormat = "$shape `,` $dim `:` type($shape) `,` type($dim) `->` "
+                       "type($extent) attr-dict";
 
   let builders = [
     // Builder that allows passing a constant dimension as a simple integer.
@@ -251,6 +252,7 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
   }];
 
   let hasFolder = 1;
+  let verifier = [{ return ::verify(*this); }];
 }
 
 def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index a7a9cb97e76b..3bdc5cc39a7b 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -535,10 +535,38 @@ OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
 // GetExtentOp
 //===----------------------------------------------------------------------===//
 
+/// Checks that the result type of `get_extent` agrees with its operands. If
+/// either operand can hold an error value (`!shape.shape` or `!shape.size`),
+/// the result must be `!shape.size` so the error can propagate; otherwise the
+/// result must be `index`.
+static LogicalResult verify(GetExtentOp op) {
+  Type shapeTy = op.shape().getType();
+  Type dimTy = op.dim().getType();
+  Type extentTy = op.extent().getType();
+  bool errorPropagationPossible =
+      shapeTy.isa<ShapeType>() || dimTy.isa<SizeType>();
+  // Return the diagnostic so that verification fails; merely emitting the
+  // error and falling through to `success()` would accept the invalid op.
+  if (errorPropagationPossible) {
+    if (!extentTy.isa<SizeType>())
+      return op.emitError()
+          << "if at least one of the operands can hold error values then the "
+             "result must be of type `size` to propagate them";
+  } else {
+    if (extentTy.isa<SizeType>())
+      return op.emitError() << "if none of the operands can hold error values "
+                               "then the result must be of type `index`";
+  }
+  return success();
+}
+
+/// Returns the value of `dim` if it is produced by a `ConstSizeOp` or a
+/// `ConstantOp`, and `llvm::None` otherwise.
 Optional<int64_t> GetExtentOp::getConstantDim() {
-  if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) {
+  if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
     return constSizeOp.value().getLimitedValue();
-  }
+  if (auto constantOp = dim().getDefiningOp<ConstantOp>())
+    return constantOp.value().cast<IntegerAttr>().getInt();
   return llvm::None;
 }
 
@@ -558,8 +586,16 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
                         int64_t dim) {
   auto loc = result.location;
   auto dimAttr = builder.getIndexAttr(dim);
-  Value dimValue = builder.create<ConstSizeOp>(loc, dimAttr);
-  build(builder, result, shape, dimValue);
+  // A `!shape.shape` operand can hold an error, so the dimension and result
+  // are `!shape.size`; for an error-free extent tensor both are `index`.
+  if (shape.getType().isa<ShapeType>()) {
+    Value dimValue = builder.create<ConstSizeOp>(loc, dimAttr);
+    build(builder, result, builder.getType<SizeType>(), shape, dimValue);
+  } else {
+    Value dimValue =
+        builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
+    build(builder, result, builder.getIndexType(), shape, dimValue);
+  }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index d8c0cbd5f9de..441024e9773e 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -136,28 +136,25 @@ func @rank(%shape : tensor<?xindex>) -> index {
 // `shape_of` operation.
 // CHECK-LABEL: @get_extent_shape_of
 // CHECK-SAME:  (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
-func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : !shape.size)
-    -> !shape.size {
+func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : index) -> index {
   // CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
   // CHECK: return %[[RESULT]] : index
   %shape = shape.shape_of %arg : tensor<2x3xf32> -> tensor<?xindex>
-  %result = shape.get_extent %shape, %idx : tensor<?xindex>
-  return %result : !shape.size
+  %result = shape.get_extent %shape, %idx : tensor<?xindex>, index -> index
+  return %result : index
 }
 
 // -----
 
-// Express `get_extent` as `std.extract_element` when it relies directly on the
-// outcome of a `from_extent_tensor` operation.
+// Express `get_extent` as `std.extract_element`.
 // CHECK-LABEL: @get_extent_from_extent_tensor
 // CHECK-SAME: (%[[EXTENTS:.*]]: tensor<?xindex>, %[[IDX:.*]]: index) -> index
-func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
-                                    %idx : !shape.size) -> !shape.size {
+func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
+    -> index {
   // CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
   // CHECK: return %[[RESULT]] : index
-  %shape = shape.from_extent_tensor %extents : tensor<?xindex>
-  %result = shape.get_extent %shape, %idx : !shape.shape
-  return %result : !shape.size
+  %result = shape.get_extent %extents, %idx : tensor<?xindex>, index -> index
+  return %result : index
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 4d8fca8d1318..b4dca5e3c2bf 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -235,13 +235,49 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
 
 // -----
 
+// Basic folding.
+// CHECK-LABEL: func @basic
+func @basic() -> index {
+  // CHECK: constant 2 : index
+  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+  %c2 = constant 2 : index
+  %1 = shape.get_extent %0, %c2 : tensor<?xindex>, index -> index
+  return %1 : index
+}
+
+// -----
+
+// Should not fold.
+// CHECK-LABEL: func @out_of_bounds
+func @out_of_bounds() -> index {
+  // CHECK: shape.const_shape
+  // CHECK: shape.get_extent
+  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+  %c3 = constant 3 : index
+  %1 = shape.get_extent %0, %c3 : tensor<?xindex>, index -> index
+  return %1 : index
+}
+
+// -----
+
+// Should not fold.
+// CHECK-LABEL: func @not_const
+func @not_const(%arg0: tensor<?xindex>) -> index {
+  // CHECK: shape.get_extent
+  %c3 = constant 3 : index
+  %0 = shape.get_extent %arg0, %c3 : tensor<?xindex>, index -> index
+  return %0 : index
+}
+
+// -----
+
 // Basic folding.
 // CHECK-LABEL: func @basic
 func @basic() -> !shape.size {
   // CHECK: shape.const_size 2
-  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+  %0 = shape.const_shape [0, 1, 2] : !shape.shape
   %c2 = shape.const_size 2
-  %1 = shape.get_extent %0, %c2 : tensor<?xindex>
+  %1 = shape.get_extent %0, %c2 : !shape.shape, !shape.size -> !shape.size
   return %1 : !shape.size
 }
 
@@ -252,9 +288,9 @@ func @basic() -> !shape.size {
 func @out_of_bounds() -> !shape.size {
   // CHECK: shape.const_shape
   // CHECK: shape.get_extent
-  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+  %0 = shape.const_shape [0, 1, 2] : !shape.shape
   %c3 = shape.const_size 3
-  %1 = shape.get_extent %0, %c3 : tensor<?xindex>
+  %1 = shape.get_extent %0, %c3 : !shape.shape, !shape.size -> !shape.size
   return %1 : !shape.size
 }
 
@@ -262,14 +298,13 @@ func @out_of_bounds() -> !shape.size {
 
 // Should not fold.
 // CHECK-LABEL: func @not_const
-func @not_const(%arg0: tensor<?xindex>) -> !shape.size {
+func @not_const(%arg0 : !shape.shape) -> !shape.size {
   // CHECK: shape.get_extent
   %c3 = shape.const_size 3
-  %0 = shape.get_extent %arg0, %c3 : tensor<?xindex>
+  %0 = shape.get_extent %arg0, %c3 : !shape.shape, !shape.size -> !shape.size
   return %0 : !shape.size
 }
 
-
 // -----
 // cstr_eq with non-constant but known equal shapes can be removed.
 // CHECK-LABEL: func @f

diff  --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index ae25ba90c360..d7e9e40ed3f2 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -102,3 +102,23 @@ func @rank(%arg : !shape.shape) {
   %0 = shape.rank %arg : !shape.shape -> index
 }
 
+// -----
+
+// Both operands are error-free, so a `!shape.size` result is rejected.
+func @get_extent_error_free(%arg : tensor<?xindex>) -> !shape.size {
+  %c0 = constant 0 : index
+  // expected-error@+1 {{if none of the operands can hold error values then the result must be of type `index`}}
+  %result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> !shape.size
+  return %result : !shape.size
+}
+
+// -----
+
+// A `!shape.size` dim may carry an error, so an `index` result is rejected.
+func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
+  %c0 = shape.const_size 0
+  // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
+  %result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> index
+  return %result : index
+}
+

diff  --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 3b44af99b4fe..b6b839251a88 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -163,13 +163,20 @@ func @shape_eq_on_mixed(%a : tensor<?xindex>, %b : !shape.shape) -> i1 {
 
 func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
   %c0 = shape.const_size 0
-  %result = shape.get_extent %arg, %c0 : !shape.shape
+  %result = shape.get_extent %arg, %c0 :
+      !shape.shape, !shape.size -> !shape.size
   return %result : !shape.size
 }
 
-func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> !shape.size {
+func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> index {
+  %c0 = constant 0 : index
+  %result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> index
+  return %result : index
+}
+
+func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
   %c0 = shape.const_size 0
-  %result = shape.get_extent %arg, %c0 : tensor<?xindex>
+  %result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> !shape.size
   return %result : !shape.size
 }
 


        


More information about the Mlir-commits mailing list