[Mlir-commits] [mlir] 5984d74 - [MLIR][Shape] Allow `get_extent` to operate on extent tensors and indices
Frederik Gossen
llvmlistbot at llvm.org
Fri Jul 24 04:13:36 PDT 2020
Author: Frederik Gossen
Date: 2020-07-24T11:13:17Z
New Revision: 5984d74139d45cec5cc8c55b107b9cb5d801c03e
URL: https://github.com/llvm/llvm-project/commit/5984d74139d45cec5cc8c55b107b9cb5d801c03e
DIFF: https://github.com/llvm/llvm-project/commit/5984d74139d45cec5cc8c55b107b9cb5d801c03e.diff
LOG: [MLIR][Shape] Allow `get_extent` to operate on extent tensors and indices
Differential Revision: https://reviews.llvm.org/D84435
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/invalid.mlir
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 64dba487c507..32d6ebafff32 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -235,9 +235,10 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
an error then it returns an error size.
}];
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
- Shape_SizeType:$dim);
- let results = (outs Shape_SizeType:$extent);
- let assemblyFormat = "$shape `,` $dim `:` type($shape) attr-dict";
+ Shape_SizeOrIndexType:$dim);
+ let results = (outs Shape_SizeOrIndexType:$extent);
+ let assemblyFormat = "$shape `,` $dim `:` type($shape) `,` type($dim) `->` "
+ "type($extent) attr-dict";
let builders = [
// Builder that allows passing a constant dimension as a simple integer.
@@ -251,6 +252,7 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
}];
let hasFolder = 1;
+ let verifier = [{ return ::verify(*this); }];
}
def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index a7a9cb97e76b..3bdc5cc39a7b 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -535,10 +535,30 @@ OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
// GetExtentOp
//===----------------------------------------------------------------------===//
+static LogicalResult verify(GetExtentOp op) {
+  Type shapeTy = op.shape().getType();
+  Type dimTy = op.dim().getType();
+  Type extentTy = op.extent().getType();
+  // Result must be `size` iff an operand is `!shape.shape`/`!shape.size`.
+  bool errorPropagationPossible =
+      shapeTy.isa<ShapeType>() || dimTy.isa<SizeType>();
+  if (errorPropagationPossible) {
+    if (!extentTy.isa<SizeType>())
+      return op.emitError()
+             << "if at least one of the operands can hold error values then the "
+                "result must be of type `size` to propagate them";
+  } else if (extentTy.isa<SizeType>()) {
+    return op.emitError() << "if none of the operands can hold error values "
+                             "then the result must be of type `index`";
+  }
+  return success();
+}
+
Optional<int64_t> GetExtentOp::getConstantDim() {
- if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) {
+ if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
return constSizeOp.value().getLimitedValue();
- }
+ if (auto constantOp = dim().getDefiningOp<ConstantOp>())
+ return constantOp.value().cast<IntegerAttr>().getInt();
return llvm::None;
}
@@ -558,8 +578,14 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
int64_t dim) {
auto loc = result.location;
auto dimAttr = builder.getIndexAttr(dim);
- Value dimValue = builder.create<ConstSizeOp>(loc, dimAttr);
- build(builder, result, shape, dimValue);
+ if (shape.getType().isa<ShapeType>()) {
+ Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
+ build(builder, result, builder.getType<SizeType>(), shape, dim);
+ } else {
+ Value dim =
+ builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
+ build(builder, result, builder.getIndexType(), shape, dim);
+ }
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index d8c0cbd5f9de..441024e9773e 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -136,28 +136,25 @@ func @rank(%shape : tensor<?xindex>) -> index {
// `shape_of` operation.
// CHECK-LABEL: @get_extent_shape_of
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
-func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : !shape.size)
- -> !shape.size {
+func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : index) -> index {
// CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
// CHECK: return %[[RESULT]] : index
%shape = shape.shape_of %arg : tensor<2x3xf32> -> tensor<?xindex>
- %result = shape.get_extent %shape, %idx : tensor<?xindex>
- return %result : !shape.size
+ %result = shape.get_extent %shape, %idx : tensor<?xindex>, index -> index
+ return %result : index
}
// -----
-// Express `get_extent` as `std.extract_element` when it relies directly on the
-// outcome of a `from_extent_tensor` operation.
+// Express `get_extent` as `std.extract_element`.
// CHECK-LABEL: @get_extent_from_extent_tensor
// CHECK-SAME: (%[[EXTENTS:.*]]: tensor<?xindex>, %[[IDX:.*]]: index) -> index
-func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
- %idx : !shape.size) -> !shape.size {
+func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
+ -> index {
// CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
// CHECK: return %[[RESULT]] : index
- %shape = shape.from_extent_tensor %extents : tensor<?xindex>
- %result = shape.get_extent %shape, %idx : !shape.shape
- return %result : !shape.size
+ %result = shape.get_extent %extents, %idx : tensor<?xindex>, index -> index
+ return %result : index
}
// -----
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 4d8fca8d1318..b4dca5e3c2bf 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -235,13 +235,49 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
// -----
+// Basic folding.
+// CHECK-LABEL: func @basic
+func @basic() -> index {
+ // CHECK: constant 2 : index
+ %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+ %c2 = constant 2 : index
+ %1 = shape.get_extent %0, %c2 : tensor<?xindex>, index -> index
+ return %1 : index
+}
+
+// -----
+
+// Should not fold.
+// CHECK-LABEL: func @out_of_bounds
+func @out_of_bounds() -> index {
+ // CHECK: shape.const_shape
+ // CHECK: shape.get_extent
+ %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+ %c3 = constant 3 : index
+ %1 = shape.get_extent %0, %c3 : tensor<?xindex>, index -> index
+ return %1 : index
+}
+
+// -----
+
+// Should not fold.
+// CHECK-LABEL: func @not_const
+func @not_const(%arg0: tensor<?xindex>) -> index {
+ // CHECK: shape.get_extent
+ %c3 = constant 3 : index
+ %0 = shape.get_extent %arg0, %c3 : tensor<?xindex>, index -> index
+ return %0 : index
+}
+
+// -----
+
// Basic folding.
// CHECK-LABEL: func @basic
func @basic() -> !shape.size {
// CHECK: shape.const_size 2
- %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+ %0 = shape.const_shape [0, 1, 2] : !shape.shape
%c2 = shape.const_size 2
- %1 = shape.get_extent %0, %c2 : tensor<?xindex>
+ %1 = shape.get_extent %0, %c2 : !shape.shape, !shape.size -> !shape.size
return %1 : !shape.size
}
@@ -252,9 +288,9 @@ func @basic() -> !shape.size {
func @out_of_bounds() -> !shape.size {
// CHECK: shape.const_shape
// CHECK: shape.get_extent
- %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+ %0 = shape.const_shape [0, 1, 2] : !shape.shape
%c3 = shape.const_size 3
- %1 = shape.get_extent %0, %c3 : tensor<?xindex>
+ %1 = shape.get_extent %0, %c3 : !shape.shape, !shape.size -> !shape.size
return %1 : !shape.size
}
@@ -262,14 +298,13 @@ func @out_of_bounds() -> !shape.size {
// Should not fold.
// CHECK-LABEL: func @not_const
-func @not_const(%arg0: tensor<?xindex>) -> !shape.size {
+func @not_const(%arg0 : !shape.shape) -> !shape.size {
// CHECK: shape.get_extent
%c3 = shape.const_size 3
- %0 = shape.get_extent %arg0, %c3 : tensor<?xindex>
+ %0 = shape.get_extent %arg0, %c3 : !shape.shape, !shape.size -> !shape.size
return %0 : !shape.size
}
-
// -----
// cstr_eq with non-constant but known equal shapes can be removed.
// CHECK-LABEL: func @f
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index ae25ba90c360..d7e9e40ed3f2 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -102,3 +102,21 @@ func @rank(%arg : !shape.shape) {
%0 = shape.rank %arg : !shape.shape -> index
}
+// -----
+
+func @get_extent_error_free(%arg : tensor<?xindex>) -> !shape.size {
+  %c0 = constant 0 : index
+  // expected-error @+1 {{if none of the operands can hold error values then the result must be of type `index`}}
+  %result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> !shape.size
+  return %result : !shape.size
+}
+
+// -----
+
+func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
+  %c0 = shape.const_size 0
+  // expected-error @+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
+  %result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> index
+  return %result : index
+}
+
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 3b44af99b4fe..b6b839251a88 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -163,13 +163,20 @@ func @shape_eq_on_mixed(%a : tensor<?xindex>, %b : !shape.shape) -> i1 {
func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
%c0 = shape.const_size 0
- %result = shape.get_extent %arg, %c0 : !shape.shape
+ %result = shape.get_extent %arg, %c0 :
+ !shape.shape, !shape.size -> !shape.size
return %result : !shape.size
}
-func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> !shape.size {
+func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> index {
+ %c0 = constant 0 : index
+ %result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> index
+ return %result : index
+}
+
+func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
%c0 = shape.const_size 0
- %result = shape.get_extent %arg, %c0 : tensor<?xindex>
+ %result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> !shape.size
return %result : !shape.size
}
More information about the Mlir-commits
mailing list