[Mlir-commits] [mlir] e4184c8 - [MLIR][Shape] Make dimension an operand of `get_extent`
Frederik Gossen
llvmlistbot at llvm.org
Wed Jun 10 04:47:42 PDT 2020
Author: Frederik Gossen
Date: 2020-06-10T11:47:18Z
New Revision: e4184c84ca0662c73780e94763e83ec245b5a2b0
URL: https://github.com/llvm/llvm-project/commit/e4184c84ca0662c73780e94763e83ec245b5a2b0
DIFF: https://github.com/llvm/llvm-project/commit/e4184c84ca0662c73780e94763e83ec245b5a2b0.diff
LOG: [MLIR][Shape] Make dimension an operand of `get_extent`
The operation `get_extent` now accepts the dimension as an operand and is no
longer limited to constant dimensions.
A helper builder that takes the dimension as a plain integer, together with the
`getConstantDim()` accessor, covers the common constant-dimension use case.
Differential Revision: https://reviews.llvm.org/D81248
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 2393aa1865a1..334a15ba2459 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -190,24 +190,26 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
let summary = "Gets the specified extent from a shape";
let description = [{
Gets the extent indexed by `dim` from `shape`.
-
If the shape is an error, it returns an error size.
}];
let arguments = (ins
Shape_ShapeType:$shape,
- Confined<I64Attr, [IntNonNegative]>:$dim
+ Shape_SizeType:$dim
);
let results = (outs Shape_SizeType:$extent);
let assemblyFormat = "$shape `,` $dim attr-dict";
let builders = [
- // Builder that allows passing a simple integer instead of an IntegerAttr.
- OpBuilder<
- [{OpBuilder &builder, OperationState &result, Value shape, int64_t dim}],
- [{build(builder, result, shape, builder.getI64IntegerAttr(dim));}]
- >
+ // Builder that allows passing a constant dimension as a simple integer.
+ OpBuilder<"OpBuilder &builder, OperationState &result, Value shape, "
+ "int64_t dim">
];
+ let extraClassDeclaration = [{
+ /// Get the `dim` value as integer if it is constant.
+ Optional<int64_t> getConstantDim();
+ }];
+
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 9df20e4871c7..07913273b810 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -405,15 +405,31 @@ OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
// GetExtentOp
//===----------------------------------------------------------------------===//
+Optional<int64_t> GetExtentOp::getConstantDim() {
+ if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) {
+ return constSizeOp.value().getLimitedValue();
+ }
+ return llvm::None;
+}
+
OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
if (!elements)
return nullptr;
- uint64_t dimToGet = dim().getLimitedValue();
- // TODO: Constant fold this to some kind of constant error.
- if (dimToGet >= (uint64_t)elements.getNumElements())
+ Optional<int64_t> dim = getConstantDim();
+ if (!dim.hasValue())
return nullptr;
- return elements.getValue({dimToGet});
+ if (dim.getValue() >= elements.getNumElements())
+ return nullptr;
+ return elements.getValue({(uint64_t)dim.getValue()});
+}
+
+void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
+ int64_t dim) {
+ auto loc = result.location;
+ auto dimAttr = builder.getIndexAttr(dim);
+ Value dimValue = builder.create<ConstSizeOp>(loc, dimAttr);
+ build(builder, result, shape, dimValue);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 7c90753e255e..6693e44d9b30 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1,6 +1,5 @@
// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize <%s | FileCheck %s --dump-input=fail
-// -----
// CHECK-LABEL: func @f
func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
// CHECK: shape.const_shape [2, 3, 4]
@@ -9,6 +8,7 @@ func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
}
// -----
+
// Basic case.
// CHECK-LABEL: func @f
func @f() -> (!shape.shape, !shape.shape) {
@@ -22,6 +22,7 @@ func @f() -> (!shape.shape, !shape.shape) {
}
// -----
+
// Negative split point.
// CHECK-LABEL: func @f
func @f() -> (!shape.shape, !shape.shape) {
@@ -34,6 +35,7 @@ func @f() -> (!shape.shape, !shape.shape) {
}
// -----
+
// Out of range split point. No folding.
// CHECK-LABEL: func @f
func @f() -> (!shape.shape, !shape.shape) {
@@ -45,6 +47,7 @@ func @f() -> (!shape.shape, !shape.shape) {
}
// -----
+
// Basic case.
// CHECK-LABEL: func @f
func @f() -> !shape.shape {
@@ -56,6 +59,7 @@ func @f() -> !shape.shape {
}
// -----
+
// Incompatible shapes. No folding.
// CHECK-LABEL: func @f
func @f() -> !shape.shape {
@@ -67,6 +71,7 @@ func @f() -> !shape.shape {
}
// -----
+
// Basic case.
// CHECK-LABEL: func @f
func @f() -> !shape.shape {
@@ -78,6 +83,7 @@ func @f() -> !shape.shape {
}
// -----
+
// Basic case.
// CHECK-LABEL: func @f
func @f() -> tensor<2xindex> {
@@ -88,6 +94,7 @@ func @f() -> tensor<2xindex> {
}
// -----
+
// Basic case.
// CHECK-LABEL: func @f()
func @f() -> !shape.shape {
@@ -99,6 +106,8 @@ func @f() -> !shape.shape {
return %ret : !shape.shape
}
+// -----
+
// CHECK-LABEL: func @no_fold
func @no_fold(%arg0: index) -> !shape.shape {
// CHECK-NOT: shape.const_shape
@@ -108,6 +117,7 @@ func @no_fold(%arg0: index) -> !shape.shape {
}
// -----
+
// Cast constant size to index and fold it away.
// CHECK-LABEL: func @const_size_to_index
func @const_size_to_index() -> index {
@@ -119,6 +129,7 @@ func @const_size_to_index() -> index {
}
// -----
+
// Cast constant index to size and fold it away.
// CHECK-LABEL: func @const_index_to_size
func @const_index_to_size() -> !shape.size {
@@ -130,6 +141,7 @@ func @const_index_to_size() -> !shape.size {
}
// -----
+
// Cast constant index to size, then back, and fold it away.
// CHECK-LABEL: func @const_index_to_size_to_index
func @const_index_to_size_to_index() -> index {
@@ -143,6 +155,7 @@ func @const_index_to_size_to_index() -> index {
}
// -----
+
// No folding.
// CHECK-LABEL: func @nonfoldable_size_to_index
func @nonfoldable_size_to_index(%cs : !shape.size) -> index {
@@ -152,6 +165,7 @@ func @nonfoldable_size_to_index(%cs : !shape.size) -> index {
}
// -----
+
// No folding.
// CHECK-LABEL: func @nonfoldable_index_to_size
func @nonfoldable_index_to_size(%ci : index) -> !shape.size {
@@ -161,6 +175,7 @@ func @nonfoldable_index_to_size(%ci : index) -> !shape.size {
}
// -----
+
// Fold number of elements computation.
// CHECK-LABEL: func @num_elements
func @num_elements() -> !shape.size {
@@ -174,6 +189,7 @@ func @num_elements() -> !shape.size {
}
// -----
+
// No folding.
// CHECK-LABEL: func @nonfoldable_num_elements
func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
@@ -184,32 +200,37 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
// -----
-// Canonicalization of shape.get_extent
-
// Basic folding.
// CHECK-LABEL: func @basic
func @basic() -> !shape.size {
// CHECK: shape.const_size 2
%0 = shape.const_shape [0, 1, 2]
- %1 = shape.get_extent %0, 2
+ %c2 = shape.const_size 2
+ %1 = shape.get_extent %0, %c2
return %1 : !shape.size
}
+// -----
+
// Should not fold.
// CHECK-LABEL: func @out_of_bounds
func @out_of_bounds() -> !shape.size {
// CHECK: shape.const_shape
// CHECK: shape.get_extent
%0 = shape.const_shape [0, 1, 2]
- %1 = shape.get_extent %0, 3
+ %c3 = shape.const_size 3
+ %1 = shape.get_extent %0, %c3
return %1 : !shape.size
}
+// -----
+
// Should not fold.
// CHECK-LABEL: func @not_const
func @not_const(%arg0: !shape.shape) -> !shape.size {
// CHECK: shape.get_extent
- %0 = shape.get_extent %arg0, 3
+ %c3 = shape.const_size 3
+ %0 = shape.get_extent %arg0, %c3
return %0 : !shape.size
}
@@ -270,6 +291,7 @@ func @f(%arg0: !shape.shape, %arg1: !shape.shape) {
}
// -----
+
// assuming_all with known passing witnesses can be folded
// CHECK-LABEL: func @f
func @f() {
@@ -285,6 +307,7 @@ func @f() {
}
// -----
+
// assuming_all should not be removed if not all witnesses are statically passing.
//
// Additionally check that the attribute is moved to the end as this op is
@@ -303,6 +326,7 @@ func @f() {
}
// -----
+
// any can be replaced with a constant input if it has one.
// CHECK-LABEL: func @f
func @f(%arg0 : !shape.shape) -> !shape.shape {
@@ -315,6 +339,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
// -----
+
// Folding of any with partially constant operands is not yet implemented.
// CHECK-LABEL: func @f
func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
@@ -325,6 +350,7 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
}
// -----
+
// assuming with a known passing witness can be removed
// CHECK-LABEL: func @f
func @f() {
@@ -341,6 +367,7 @@ func @f() {
}
// -----
+
// assuming without a known passing witness cannot be removed
// CHECK-LABEL: func @f
func @f() {
More information about the Mlir-commits mailing list