[Mlir-commits] [mlir] 0eb50e6 - [MLIR][Shape] Allow `shape.reduce` to operate on extent tensors
Frederik Gossen
llvmlistbot at llvm.org
Thu Jul 16 06:54:14 PDT 2020
Author: Frederik Gossen
Date: 2020-07-16T13:53:37Z
New Revision: 0eb50e614c65d189a3f1bdf747be973829046bc1
URL: https://github.com/llvm/llvm-project/commit/0eb50e614c65d189a3f1bdf747be973829046bc1
DIFF: https://github.com/llvm/llvm-project/commit/0eb50e614c65d189a3f1bdf747be973829046bc1.diff
LOG: [MLIR][Shape] Allow `shape.reduce` to operate on extent tensors
Allow `shape.reduce` to take both `shape.shape` and `tensor<?xindex>` as an
argument.
Differential Revision: https://reviews.llvm.org/D83943
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
mlir/test/Dialect/Shape/invalid.mlir
mlir/test/Dialect/Shape/ops.mlir
mlir/test/Dialect/Shape/shape-to-shape.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 1f141a2e705a..090b4c6f4abb 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -338,23 +338,26 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
def Shape_ReduceOp : Shape_Op<"reduce",
[SingleBlockImplicitTerminator<"YieldOp">]> {
- let summary = "Returns an expression reduced over a shape";
+ let summary = "Returns an expression reduced over a shape or extent tensor";
let description = [{
- An operation that takes as input a shape, number of initial values and has a
- region/function that is applied repeatedly for every dimension of the shape.
+ An operation that takes as input a shape or extent tensor, and a number of
+ initial values. This operation has a region/function that is applied
+ repeatedly for every extent of the input. Starting with the initial values,
+ the individual extents are then aggregated as defined by the associated
+ region.
Conceptually this op performs the following reduction:
```
res[] = init;
- for (int i = 0, e = shape.rank(); i != e; ++i) {
+ for (int i = 0; i < shape.rank(); i++) {
res = fn(i, shape[i], res[0], ..., res[n]);
}
```
- Where fn is provided by the user and the result of the reduce op is the
+ Where `fn` is provided by the user and the result of the reduce op is the
last computed output of the reduce function. As an example, computing the
- number of elements
+ number of elements can be defined as follows:
```mlir
func @reduce(%shape : !shape.shape, %init : !shape.size) -> !shape.size {
@@ -367,11 +370,10 @@ def Shape_ReduceOp : Shape_Op<"reduce",
return %num_elements : !shape.size
}
```
-
- If the shape is unranked, then the results of the op is also unranked.
}];
- let arguments = (ins Shape_ShapeType:$shape, Variadic<AnyType>:$initVals);
+ let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
+ Variadic<AnyType>:$initVals);
let results = (outs Variadic<AnyType>:$result);
let regions = (region SizedRegion<1>:$region);
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index a6f54053a326..b983968b124d 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -721,18 +721,31 @@ static LogicalResult verify(ReduceOp op) {
// Verify block arg types.
Block &block = op.region().front();
+ // The block takes index, extent, and aggregated values as arguments.
auto blockArgsCount = op.initVals().size() + 2;
if (block.getNumArguments() != blockArgsCount)
return op.emitOpError() << "ReduceOp body is expected to have "
<< blockArgsCount << " arguments";
- if (block.getArgument(0).getType() != IndexType::get(op.getContext()))
+ // The first block argument is the index and must always be of type `index`.
+ if (!block.getArgument(0).getType().isa<IndexType>())
return op.emitOpError(
"argument 0 of ReduceOp body is expected to be of IndexType");
- if (block.getArgument(1).getType() != SizeType::get(op.getContext()))
- return op.emitOpError(
- "argument 1 of ReduceOp body is expected to be of SizeType");
+ // The second block argument is the extent and must be of type `size` or
+ // `index`, depending on whether the reduce operation is applied to a shape or
+ // to an extent tensor.
+ Type extentTy = block.getArgument(1).getType();
+ if (op.shape().getType().isa<ShapeType>()) {
+ if (!extentTy.isa<SizeType>())
+ return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
+ "SizeType if the ReduceOp operates on a ShapeType");
+ } else {
+ if (!extentTy.isa<IndexType>())
+ return op.emitOpError(
+ "argument 1 of ReduceOp body is expected to be of IndexType if the "
+ "ReduceOp operates on an extent tensor");
+ }
for (auto type : llvm::enumerate(op.initVals()))
if (block.getArgument(type.index() + 2).getType() != type.value().getType())
@@ -743,17 +756,18 @@ static LogicalResult verify(ReduceOp op) {
}
static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
- auto *ctx = parser.getBuilder().getContext();
// Parse operands.
SmallVector<OpAsmParser::OperandType, 3> operands;
+ Type shapeOrExtentTensorType;
if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren) ||
+ parser.parseColonType(shapeOrExtentTensorType) ||
parser.parseOptionalArrowTypeList(result.types))
return failure();
// Resolve operands.
auto initVals = llvm::makeArrayRef(operands).drop_front();
- if (parser.resolveOperand(operands.front(), ShapeType::get(ctx),
+ if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
result.operands) ||
parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
result.operands))
@@ -773,7 +787,7 @@ static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
static void print(OpAsmPrinter &p, ReduceOp op) {
p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
- << ") ";
+ << ") : " << op.shape().getType();
p.printOptionalArrowTypeList(op.getResultTypes());
p.printRegion(op.region());
p.printOptionalAttrDict(op.getAttrs());
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
index 1c214567c63a..9051054b3f18 100644
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -1,10 +1,10 @@
// RUN: mlir-opt -convert-shape-to-scf -split-input-file %s | FileCheck %s
-// CHECK-LABEL: shape_reduce
-// CHECK-SAME: [[SHAPE:%.*]]: !shape.shape) -> !shape.size {
+// CHECK-LABEL: @shape_reduce
+// CHECK-SAME: ([[SHAPE:%.*]]: !shape.shape) -> !shape.size
func @shape_reduce(%shape : !shape.shape) -> !shape.size {
%init = shape.const_size 1
- %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+ %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
%new_acc = shape.mul %acc, %dim
shape.yield %new_acc : !shape.size
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index da059a489be3..3aca3677c143 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -2,7 +2,7 @@
func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
// expected-error@+1 {{ReduceOp body is expected to have 3 arguments}}
- %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+ %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index: index, %dim: !shape.size):
shape.yield %dim : !shape.size
}
@@ -12,7 +12,7 @@ func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {
// expected-error@+1 {{argument 0 of ReduceOp body is expected to be of IndexType}}
- %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+ %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index: f32, %dim: !shape.size, %acc: !shape.size):
%new_acc = "shape.add"(%acc, %dim)
: (!shape.size, !shape.size) -> !shape.size
@@ -23,8 +23,8 @@ func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {
// -----
func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) {
- // expected-error@+1 {{argument 1 of ReduceOp body is expected to be of SizeType}}
- %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+ // expected-error@+1 {{argument 1 of ReduceOp body is expected to be of SizeType if the ReduceOp operates on a ShapeType}}
+ %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index: index, %dim: f32, %lci: !shape.size):
shape.yield
}
@@ -32,9 +32,19 @@ func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) {
// -----
+func @reduce_op_arg1_wrong_type(%shape : tensor<?xindex>, %init : index) {
+ // expected-error@+1 {{argument 1 of ReduceOp body is expected to be of IndexType if the ReduceOp operates on an extent tensor}}
+ %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
+ ^bb0(%index: index, %dim: f32, %lci: index):
+ shape.yield
+ }
+}
+
+// -----
+
func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) {
// expected-error@+1 {{type mismatch between argument 2 of ReduceOp body and initial value 0}}
- %num_elements = shape.reduce(%shape, %init) -> f32 {
+ %num_elements = shape.reduce(%shape, %init) : !shape.shape -> f32 {
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
shape.yield
}
@@ -44,7 +54,7 @@ func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) {
func @yield_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
// expected-error@+3 {{number of operands does not match number of results of its parent}}
- %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+ %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
shape.yield %dim, %dim : !shape.size, !shape.size
}
@@ -54,7 +64,7 @@ func @yield_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
func @yield_op_type_mismatch(%shape : !shape.shape, %init : !shape.size) {
// expected-error@+4 {{types mismatch between yield op and its parent}}
- %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+ %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
%c0 = constant 1 : index
shape.yield %c0 : index
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 3a0bcf713073..c6f52519ad2e 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -6,15 +6,26 @@
// CHECK-LABEL: shape_num_elements
func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
- %init = shape.const_size 0
- %num_elements = shape.reduce(%shape, %init) -> !shape.size {
- ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
- %acc = shape.add %lci, %dim
- shape.yield %acc : !shape.size
+ %init = shape.const_size 1
+ %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
+ ^bb0(%index : index, %extent : !shape.size, %acc : !shape.size):
+ %acc_next = shape.mul %acc, %extent
+ shape.yield %acc_next : !shape.size
}
return %num_elements : !shape.size
}
+// CHECK-LABEL: extent_tensor_num_elements
+func @extent_tensor_num_elements(%shape : tensor<?xindex>) -> index {
+ %init = constant 1 : index
+ %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
+ ^bb0(%index : index, %extent : index, %acc : index):
+ %acc_next = muli %acc, %extent : index
+ shape.yield %acc_next : index
+ }
+ return %num_elements : index
+}
+
func @test_shape_num_elements_unknown() {
%0 = "shape.unknown_shape"() : () -> !shape.shape
%1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
diff --git a/mlir/test/Dialect/Shape/shape-to-shape.mlir b/mlir/test/Dialect/Shape/shape-to-shape.mlir
index b3be4c9de3a1..9a75f0b9ca1b 100644
--- a/mlir/test/Dialect/Shape/shape-to-shape.mlir
+++ b/mlir/test/Dialect/Shape/shape-to-shape.mlir
@@ -1,16 +1,16 @@
// RUN: mlir-opt -shape-to-shape-lowering -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @num_elements_to_reduce(
-// CHECK-SAME: [[ARG:%.*]]: !shape.shape) -> [[SIZE_TY:!.*]] {
+// CHECK-SAME: [[ARG:%.*]]: !shape.shape) -> !shape.size {
func @num_elements_to_reduce(%shape : !shape.shape) -> !shape.size {
%num_elements = shape.num_elements %shape
return %num_elements : !shape.size
}
// CHECK: [[C1:%.*]] = shape.const_size 1
-// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]]) -> [[SIZE_TY]]
-// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: [[SIZE_TY]], [[ACC:%.*]]: [[SIZE_TY]]
+// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]]) : !shape.shape -> !shape.size
+// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: !shape.size, [[ACC:%.*]]: !shape.size
// CHECK: [[NEW_ACC:%.*]] = shape.mul [[DIM]], [[ACC]]
-// CHECK: shape.yield [[NEW_ACC]] : [[SIZE_TY]]
+// CHECK: shape.yield [[NEW_ACC]] : !shape.size
// CHECK: }
-// CHECK: return [[NUM_ELEMENTS]] : [[SIZE_TY]]
+// CHECK: return [[NUM_ELEMENTS]] : !shape.size
More information about the Mlir-commits
mailing list