[Mlir-commits] [mlir] 04fb2b6 - [Mlir] Implement printer, parser, verifier and builder for shape.reduce.

Alexander Belyaev llvmlistbot at llvm.org
Fri Jun 5 02:26:18 PDT 2020


Author: Alexander Belyaev
Date: 2020-06-05T11:25:32+02:00
New Revision: 04fb2b6123ee66e09b1956ff68b5436fe43cd3b4

URL: https://github.com/llvm/llvm-project/commit/04fb2b6123ee66e09b1956ff68b5436fe43cd3b4
DIFF: https://github.com/llvm/llvm-project/commit/04fb2b6123ee66e09b1956ff68b5436fe43cd3b4.diff

LOG: [Mlir] Implement printer, parser, verifier and builder for shape.reduce.

Differential Revision: https://reviews.llvm.org/D81186

Added: 
    mlir/test/Dialect/Shape/invalid.mlir

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 5fc2aa4fa2d6..ac5bedf3d6e3 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -290,7 +290,8 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
   let hasFolder = 1;
 }
 
-def Shape_ReduceOp : Shape_Op<"reduce", []> {
+def Shape_ReduceOp : Shape_Op<"reduce",
+    [SingleBlockImplicitTerminator<"YieldOp">]> {  // Single body block ended by shape.yield.
   let summary = "Returns an expression reduced over a shape";
   let description = [{
     An operation that takes as input a shape, number of initial values and has a
@@ -310,25 +311,32 @@ def Shape_ReduceOp : Shape_Op<"reduce", []> {
     number of elements
 
     ```mlir
-    func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
-      %0 = "shape.constant_dim"() {value = 1 : i32} : () -> !shape.size
-      %1 = "shape.reduce"(%shape, %0) ( {
-        ^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
-          %acc = "shape.mul"(%lci, %dim) :
+    func @reduce(%shape : !shape.shape, %init : !shape.size) -> !shape.size {
+      %num_elements = shape.reduce(%shape, %init) -> !shape.size  {
+        ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
+          %updated_acc = "shape.mul"(%acc, %dim) :
             (!shape.size, !shape.size) -> !shape.size
-          shape.yield %acc : !shape.size
-        }) : (!shape.shape, !shape.size) -> (!shape.size)
-      return %1 : !shape.size
+          shape.yield %updated_acc : !shape.size
+      }
+      return %num_elements : !shape.size
     }
     ```
 
     If the shape is unranked, then the results of the op is also unranked.
   }];
 
-  let arguments = (ins Shape_ShapeType:$shape, Variadic<AnyType>:$args);
+  let arguments = (ins Shape_ShapeType:$shape, Variadic<AnyType>:$initVals);
   let results = (outs Variadic<AnyType>:$result);
-
   let regions = (region SizedRegion<1>:$body);
+
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &result, "
+              "Value shape, ValueRange initVals">,  // Also creates the body block and its args.
+  ];
+
+  let verifier = [{ return ::verify(*this); }];  // Checks body arg count and types.
+  let printer = [{ return ::print(p, *this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];  // i.e. ::parseReduceOp
 }
 
 def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index ed89d5bca19a..04b1a51e986e 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -481,6 +481,89 @@ OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
   return DenseIntElementsAttr::get(type, shape);
 }
 
+//===----------------------------------------------------------------------===//
+// ReduceOp
+//===----------------------------------------------------------------------===//
+
+void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
+                     ValueRange initVals) {
+  // Operand order: the shape first, then every initial accumulator value.
+  result.addOperands(shape);
+  result.addOperands(initVals);
+  // One body block taking (index, !shape.size, <one arg per init value>).
+  Block *body = new Block;
+  result.addRegion()->push_back(body);
+  body->addArgument(builder.getIndexType());
+  body->addArgument(SizeType::get(builder.getContext()));
+  // Each init value contributes a same-typed block argument and op result.
+  for (Value initVal : initVals) {
+    body->addArgument(initVal.getType());
+    result.addTypes(initVal.getType());
+  }
+}
+
+static LogicalResult verify(ReduceOp op) {
+  // The body must take (index, !shape.size, one argument per initial value).
+  Block &body = op.body().front();
+  auto expectedArgs = op.initVals().size() + 2;
+  if (body.getNumArguments() != expectedArgs)
+    return op.emitOpError() << "ReduceOp body is expected to have "
+                            << expectedArgs << " arguments";
+  if (body.getArgument(0).getType() != IndexType::get(op.getContext()))
+    return op.emitOpError(
+        "argument 0 of ReduceOp body is expected to be of IndexType");
+  if (body.getArgument(1).getType() != SizeType::get(op.getContext()))
+    return op.emitOpError(
+        "argument 1 of ReduceOp body is expected to be of SizeType");
+  // Accumulator args start at index 2 and must match their initial values.
+  size_t argIdx = 2;
+  for (auto init : op.initVals()) {
+    if (body.getArgument(argIdx).getType() != init.getType())
+      return op.emitOpError() << "type mismatch between argument " << argIdx
+                              << " of ReduceOp body and initial value "
+                              << argIdx - 2;
+    ++argIdx;
+  }
+  return success();
+}
+
+static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
+  auto *ctx = parser.getBuilder().getContext();
+  // Custom form: `(` shape (`,` initVal)* `)` (`->` types)? region attr-dict?
+  SmallVector<OpAsmParser::OperandType, 3> operands;
+  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.parseOptionalArrowTypeList(result.types))
+    return failure();
+
+  // `()` is invalid: front()/drop_front() below need the shape operand.
+  if (operands.empty())
+    return parser.emitError(parser.getNameLoc(), "expected a shape operand");
+
+  // Resolve the shape, then the initial values against the result types.
+  auto initVals = llvm::makeArrayRef(operands).drop_front();
+  if (parser.resolveOperand(operands.front(), ShapeType::get(ctx),
+                            result.operands) ||
+      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
+                             result.operands))
+    return failure();
+
+  // Parse the body region, then the optional attribute dictionary.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}) ||
+      parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  return success();
+}
+
+static void print(OpAsmPrinter &p, ReduceOp op) {
+  // No stray ", " or space: the printers below prefix their own " -> "/" {".
+  p << op.getOperationName() << '(' << op.getOperands() << ')';
+  p.printOptionalArrowTypeList(op.getResultTypes());
+  p.printRegion(op.body());
+  p.printOptionalAttrDict(op.getAttrs());
+}
+
 namespace mlir {
 namespace shape {
 

diff  --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
new file mode 100644
index 000000000000..63589c80e221
--- /dev/null
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
+  // expected-error@+1 {{ReduceOp body is expected to have 3 arguments}}
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: index, %dim: !shape.size):
+      "shape.yield"(%dim) : (!shape.size) -> ()
+  }
+}
+
+// -----
+
+func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {
+  // expected-error@+1 {{argument 0 of ReduceOp body is expected to be of IndexType}}
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: f32, %dim: !shape.size, %lci: !shape.size):
+      %acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
+      "shape.yield"(%acc) : (!shape.size) -> ()
+  }
+}
+
+// -----
+
+func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) {
+  // expected-error@+1 {{argument 1 of ReduceOp body is expected to be of SizeType}}
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: index, %dim: f32, %lci: !shape.size):
+      "shape.yield"() : () -> ()
+  }
+}
+
+// -----
+
+func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) {
+  // expected-error@+1 {{type mismatch between argument 2 of ReduceOp body and initial value 0}}
+  %num_elements = shape.reduce(%shape, %init) -> f32 {
+    ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
+      "shape.yield"() : () -> ()
+  }
+}

diff  --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 5f316d9988b8..0df58eddc643 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -6,13 +6,13 @@
 
 // CHECK-LABEL: shape_num_elements
 func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
-  %0 = shape.const_size 0
-  %1 = "shape.reduce"(%shape, %0) ( {
-    ^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
+  %init = shape.const_size 0  // NOTE(review): body uses shape.add from 0, so this sums extents.
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
       %acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
       "shape.yield"(%acc) : (!shape.size) -> ()
-    }) : (!shape.shape, !shape.size) -> (!shape.size)
-  return %1 : !shape.size
+  }
+  return %num_elements : !shape.size
 }
 
 func @test_shape_num_elements_unknown() {


        


More information about the Mlir-commits mailing list