[Mlir-commits] [mlir] cf42b70 - [mlir][shape] Add `shape.get_extent`.
Sean Silva
llvmlistbot at llvm.org
Tue May 26 17:06:42 PDT 2020
Author: Sean Silva
Date: 2020-05-26T17:03:40-07:00
New Revision: cf42b704391c44e84485dd2547ae006196998266
URL: https://github.com/llvm/llvm-project/commit/cf42b704391c44e84485dd2547ae006196998266
DIFF: https://github.com/llvm/llvm-project/commit/cf42b704391c44e84485dd2547ae006196998266.diff
LOG: [mlir][shape] Add `shape.get_extent`.
Summary:
This op extracts an extent from a shape.
This is also the first op that constant-folds to shape.const_size,
which revealed that shape.const_size needs a folder (ConstantLike ops
appear to always need a folder for the constant-folding infrastructure to work).
Differential Revision: https://reviews.llvm.org/D80394
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 7d62cebff8e6..0278d7bbeb06 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -131,6 +131,7 @@ def Shape_ConstSizeOp : Shape_Op<"const_size",
let results = (outs Shape_SizeType:$result);
let assemblyFormat = "attr-dict $value";
+ let hasFolder = 1;
}
def Shape_FromExtentsOp : Shape_Op<"from_extents", [
@@ -190,6 +191,37 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", []> {
let hasFolder = 1;
}
+// shape.get_extent: extracts a single extent, as a !shape.size, from a
+// !shape.shape value. The result type is not spelled out in the assembly
+// format; it is inferred through InferTypeOpInterface
+// (GetExtentOp::inferReturnTypes always produces a SizeType).
+def Shape_GetExtentOp : Shape_Op<"get_extent",
+ [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ let summary = "Gets the specified extent from a shape";
+ let description = [{
+ Gets the extent indexed by `dim` from `shape`.
+
+ If the shape is an error, it returns an error size.
+ }];
+ // `dim` is a 64-bit integer attribute constrained to be non-negative.
+ let arguments = (ins
+ Shape_ShapeType:$shape,
+ Confined<I64Attr, [IntNonNegative]>:$dim
+ );
+ let results = (outs Shape_SizeType:$extent);
+ // Textual form: `shape.get_extent %shape, <dim>` followed by any attributes.
+ let assemblyFormat = "$shape `,` $dim attr-dict";
+
+ let builders = [
+ // Convenience builder: takes `dim` as a plain int64_t and wraps it in an
+ // i64 IntegerAttr before delegating to the generated builder.
+ OpBuilder<
+ [{
+ OpBuilder &builder, OperationState &result,
+ Value shape, int64_t dim
+ }],
+ [{
+ build(builder, result, shape, builder.getI64IntegerAttr(dim));
+ }]
+ >
+ ];
+
+ // Folds to a constant index when `shape` is a constant and `dim` is in
+ // bounds (see GetExtentOp::fold).
+ let hasFolder = 1;
+}
+
def Shape_JoinOp : Shape_Op<"join", []> {
let summary = "Returns the least general shape.size of its operands";
let description = [{
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 5c6a0c2204c3..095c41720fba 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -245,6 +245,8 @@ ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
return success();
}
+// A constant size folds to nothing more than its own `value` attribute;
+// exposing it through fold() lets the constant-folding machinery recognize
+// this op as a constant.
+OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) {
+ return valueAttr();
+}
+
//===----------------------------------------------------------------------===//
// FromExtentsOp
//===----------------------------------------------------------------------===//
@@ -267,6 +269,37 @@ OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
return builder.getI64TensorAttr(extents);
}
+//===----------------------------------------------------------------------===//
+// GetExtentOp
+//===----------------------------------------------------------------------===//
+
+// The result of shape.get_extent is always a !shape.size, regardless of the
+// operands, attributes, or regions, so only the context is consulted.
+LogicalResult GetExtentOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ Type sizeTy = SizeType::get(context);
+ inferredReturnTypes.emplace_back(sizeTy);
+ return success();
+}
+
+// Folds shape.get_extent when its `shape` operand is a constant.
+//
+// Returns an index-typed IntegerAttr holding the requested extent. Returns a
+// null OpFoldResult (no fold) when the shape operand is not a constant
+// DenseIntElementsAttr, or when `dim` is out of bounds for it.
+OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
+ auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+ if (!elements)
+ return nullptr;
+ uint64_t dimToGet = dim().getLimitedValue();
+ // TODO: Constant fold this to some kind of constant error.
+ if (dimToGet >= static_cast<uint64_t>(elements.getNumElements()))
+ return nullptr;
+ // The constant shape stores its extents as i64 values, but the folded size
+ // must be index-typed, so re-wrap the extracted value as an index attribute.
+ // TODO: Make ConstShapeOp hold a tensor of index instead of i64.
+ int64_t extent = elements.getValue<IntegerAttr>({dimToGet}).getInt();
+ Builder builder(getContext());
+ return builder.getIndexAttr(extent);
+}
+
//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 2e35fc748d86..018f5b212b4e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -106,3 +106,33 @@ func @no_fold(%arg0: index) -> !shape.shape {
%ret = shape.from_extents %e0, %arg0
return %ret : !shape.shape
}
+
+// -----
+// Canonicalization of shape.get_extent
+
+// Basic folding: a constant shape with an in-bounds `dim` folds to a
+// shape.const_size, and the get_extent op disappears from the output.
+// CHECK-LABEL: func @basic
+func @basic() -> !shape.size {
+ // CHECK: %[[C2:.*]] = shape.const_size 2
+ // CHECK-NOT: shape.get_extent
+ // CHECK: return %[[C2]] : !shape.size
+ %0 = shape.const_shape [0, 1, 2]
+ %1 = shape.get_extent %0, 2
+ return %1 : !shape.size
+}
+
+// Should not fold: `dim` 3 is past the end of a shape with three extents,
+// so both the constant shape and the extent query are left in place.
+// CHECK-LABEL: func @out_of_bounds
+func @out_of_bounds() -> !shape.size {
+ // CHECK: shape.const_shape
+ // CHECK: shape.get_extent
+ %shape = shape.const_shape [0, 1, 2]
+ %extent = shape.get_extent %shape, 3
+ return %extent : !shape.size
+}
+
+// Should not fold: the shape is a function argument, not a constant.
+// CHECK-LABEL: func @not_const
+func @not_const(%shape: !shape.shape) -> !shape.size {
+ // CHECK: shape.get_extent
+ %extent = shape.get_extent %shape, 3
+ return %extent : !shape.size
+}
More information about the Mlir-commits
mailing list