[Mlir-commits] [mlir] 04fb2b6 - [Mlir] Implement printer, parser, verifier and builder for shape.reduce.
Alexander Belyaev
llvmlistbot at llvm.org
Fri Jun 5 02:26:18 PDT 2020
Author: Alexander Belyaev
Date: 2020-06-05T11:25:32+02:00
New Revision: 04fb2b6123ee66e09b1956ff68b5436fe43cd3b4
URL: https://github.com/llvm/llvm-project/commit/04fb2b6123ee66e09b1956ff68b5436fe43cd3b4
DIFF: https://github.com/llvm/llvm-project/commit/04fb2b6123ee66e09b1956ff68b5436fe43cd3b4.diff
LOG: [Mlir] Implement printer, parser, verifier and builder for shape.reduce.
Differential Revision: https://reviews.llvm.org/D81186
Added:
mlir/test/Dialect/Shape/invalid.mlir
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 5fc2aa4fa2d6..ac5bedf3d6e3 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -290,7 +290,8 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
let hasFolder = 1;
}
-def Shape_ReduceOp : Shape_Op<"reduce", []> {
+def Shape_ReduceOp : Shape_Op<"reduce",
+ [SingleBlockImplicitTerminator<"YieldOp">]> {
let summary = "Returns an expression reduced over a shape";
let description = [{
An operation that takes as input a shape, number of initial values and has a
@@ -310,25 +311,32 @@ def Shape_ReduceOp : Shape_Op<"reduce", []> {
number of elements
```mlir
- func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
- %0 = "shape.constant_dim"() {value = 1 : i32} : () -> !shape.size
- %1 = "shape.reduce"(%shape, %0) ( {
- ^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
- %acc = "shape.mul"(%lci, %dim) :
+ func @reduce(%shape : !shape.shape, %init : !shape.size) -> !shape.size {
+ %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+ ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
+ %updated_acc = "shape.mul"(%acc, %dim) :
(!shape.size, !shape.size) -> !shape.size
- shape.yield %acc : !shape.size
- }) : (!shape.shape, !shape.size) -> (!shape.size)
- return %1 : !shape.size
+ shape.yield %updated_acc : !shape.size
+ }
+ return %num_elements : !shape.size
}
```
If the shape is unranked, then the results of the op is also unranked.
}];
- let arguments = (ins Shape_ShapeType:$shape, Variadic<AnyType>:$args);
+ // $initVals are the initial accumulator values; the body block receives,
+ // after its (index, !shape.size) arguments, one argument per initial value
+ // whose type must match that value's type (checked by the verifier).
+ let arguments = (ins Shape_ShapeType:$shape, Variadic<AnyType>:$initVals);
let results = (outs Variadic<AnyType>:$result);
-
let regions = (region SizedRegion<1>:$body);
+
+ // Convenience builder (ReduceOp::build in Shape.cpp): creates the body region
+ // with one empty block taking (index, !shape.size, <init value types...>) and
+ // adds one result per initial value, of that value's type. The caller is
+ // responsible for populating the body, including its terminator.
+ let builders = [
+ OpBuilder<"OpBuilder &builder, OperationState &result, "
+ "Value shape, ValueRange initVals">,
+ ];
+
+ // Custom hooks, implemented as free functions in Shape.cpp.
+ let verifier = [{ return ::verify(*this); }];
+ let printer = [{ return ::print(p, *this); }];
+ let parser = [{ return ::parse$cppClass(parser, result); }];
}
def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index ed89d5bca19a..04b1a51e986e 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -481,6 +481,89 @@ OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
return DenseIntElementsAttr::get(type, shape);
}
+//===----------------------------------------------------------------------===//
+// ReduceOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a shape.reduce op reducing over `shape`, seeded with `initVals`.
+///
+/// Operands are `shape` followed by all of `initVals`. The op gets one result
+/// per initial value, with the same type. Its body region holds a single empty
+/// block whose arguments are, in order: an `index`, a `!shape.size`, and one
+/// argument per initial value (same type as that value). No operations, and
+/// in particular no terminator, are created inside the block here.
+void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
+                     ValueRange initVals) {
+  // Operand layout: shape first, then every initial value.
+  result.addOperands(shape);
+  result.addOperands(initVals);
+
+  Block *entry = new Block;
+  // The region takes ownership of the block.
+  result.addRegion()->push_back(entry);
+
+  auto *ctx = builder.getContext();
+  entry->addArgument(builder.getIndexType());
+  entry->addArgument(SizeType::get(ctx));
+
+  // Every initial value contributes one block argument and one op result.
+  for (Value initVal : initVals) {
+    Type accType = initVal.getType();
+    entry->addArgument(accType);
+    result.addTypes(accType);
+  }
+}
+
+/// Verifies the signature of a shape.reduce body block.
+///
+/// The block must take `(index, !shape.size, <one argument per initial
+/// value>)`, where each trailing argument has exactly the type of the matching
+/// initial value. Emits an op error and returns failure on the first mismatch.
+static LogicalResult verify(ReduceOp op) {
+  Block &body = op.body().front();
+  auto initVals = op.initVals();
+  auto *ctx = op.getContext();
+
+  // Two fixed leading arguments, then one accumulator per initial value.
+  auto expectedNumArgs = initVals.size() + 2;
+  if (body.getNumArguments() != expectedNumArgs)
+    return op.emitOpError() << "ReduceOp body is expected to have "
+                            << expectedNumArgs << " arguments";
+
+  if (body.getArgument(0).getType() != IndexType::get(ctx))
+    return op.emitOpError(
+        "argument 0 of ReduceOp body is expected to be of IndexType");
+
+  if (body.getArgument(1).getType() != SizeType::get(ctx))
+    return op.emitOpError(
+        "argument 1 of ReduceOp body is expected to be of SizeType");
+
+  // Accumulator arguments are offset by the two fixed leading arguments.
+  for (auto indexedInit : llvm::enumerate(initVals)) {
+    auto argIdx = indexedInit.index() + 2;
+    if (body.getArgument(argIdx).getType() == indexedInit.value().getType())
+      continue;
+    return op.emitOpError()
+           << "type mismatch between argument " << argIdx
+           << " of ReduceOp body and initial value " << indexedInit.index();
+  }
+  return success();
+}
+
+/// Parses the custom form of shape.reduce:
+///
+///   shape.reduce(%shape[, %init...]) [-> type, ...] { body } [attr-dict]
+///
+/// The first operand is resolved as `!shape.shape`. The remaining operands are
+/// the initial values and are resolved against the parsed result types, so
+/// their number must equal the number of result types. The body region is
+/// parsed with its own explicitly declared block arguments.
+static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
+  auto *ctx = parser.getBuilder().getContext();
+  // Parse `(` operand-list `)` followed by the optional `-> types`.
+  SmallVector<OpAsmParser::OperandType, 3> operands;
+  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.parseOptionalArrowTypeList(result.types))
+    return failure();
+
+  // The operand list may legally be empty, but the shape operand is
+  // mandatory. Without this check, `operands.front()` and `drop_front()`
+  // below would run on an empty list.
+  if (operands.empty())
+    return parser.emitError(parser.getNameLoc(),
+                            "expected at least one operand (the shape)");
+
+  // Resolve the shape, then the initial values against the result types.
+  auto initVals = llvm::makeArrayRef(operands).drop_front();
+  if (parser.resolveOperand(operands.front(), ShapeType::get(ctx),
+                            result.operands) ||
+      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
+                             result.operands))
+    return failure();
+
+  // Parse the body; its block arguments are declared inline in the IR.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
+    return failure();
+
+  // Trailing optional attribute dictionary (printed after the region).
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+/// Prints shape.reduce in the custom form accepted by parseReduceOp:
+///
+///   shape.reduce(%shape[, %init...]) [-> type, ...] { body } [attr-dict]
+static void print(OpAsmPrinter &p, ReduceOp op) {
+  p << op.getOperationName() << '(' << op.shape();
+  // Emit the separator per initial value. An op without initial values then
+  // prints as `(%shape)` instead of `(%shape, )`, which would not re-parse.
+  for (Value initVal : op.initVals())
+    p << ", " << initVal;
+  p << ") ";
+  p.printOptionalArrowTypeList(op.getResultTypes());
+  p.printRegion(op.body());
+  p.printOptionalAttrDict(op.getAttrs());
+}
+
namespace mlir {
namespace shape {
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
new file mode 100644
index 000000000000..63589c80e221
--- /dev/null
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// The body must declare two leading arguments plus one per initial value
+// (3 here), but only declares 2.
+func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
+  // expected-error @+1 {{ReduceOp body is expected to have 3 arguments}}
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: index, %dim: !shape.size):
+      "shape.yield"(%dim) : (!shape.size) -> ()
+  }
+}
+
+// -----
+
+// Body argument 0 must be of `index` type; `f32` is rejected.
+func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {
+  // expected-error @+1 {{argument 0 of ReduceOp body is expected to be of IndexType}}
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: f32, %dim: !shape.size, %lci: !shape.size):
+      %acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
+      "shape.yield"(%acc) : (!shape.size) -> ()
+  }
+}
+
+// -----
+
+// Body argument 1 must be of `!shape.size` type; `f32` is rejected.
+func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) {
+  // expected-error @+1 {{argument 1 of ReduceOp body is expected to be of SizeType}}
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: index, %dim: f32, %lci: !shape.size):
+      "shape.yield"() : () -> ()
+  }
+}
+
+// -----
+
+// Each accumulator argument must have the type of its initial value: the
+// initial value is `f32`, but body argument 2 is declared `!shape.size`.
+func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) {
+  // expected-error @+1 {{type mismatch between argument 2 of ReduceOp body and initial value 0}}
+  %num_elements = shape.reduce(%shape, %init) -> f32 {
+    ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
+      "shape.yield"() : () -> ()
+  }
+}
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 5f316d9988b8..0df58eddc643 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -6,13 +6,13 @@
// CHECK-LABEL: shape_num_elements
func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
- %0 = shape.const_size 0
- %1 = "shape.reduce"(%shape, %0) ( {
- ^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
+ // NOTE(review): the body adds each extent to an accumulator seeded with 0,
+ // i.e. it computes the sum of the extents rather than the element count the
+ // function name suggests. Presumably this test only exercises the custom
+ // parser/printer round trip; confirm before relying on its semantics.
+ %init = shape.const_size 0
+ %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+ ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
%acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
"shape.yield"(%acc) : (!shape.size) -> ()
- }) : (!shape.shape, !shape.size) -> (!shape.size)
- return %1 : !shape.size
+ }
+ return %num_elements : !shape.size
}
func @test_shape_num_elements_unknown() {
More information about the Mlir-commits mailing list