[Mlir-commits] [mlir] 0eb50e6 - [MLIR][Shape] Allow `shape.reduce` to operate on extent tensors

Frederik Gossen llvmlistbot at llvm.org
Thu Jul 16 06:54:14 PDT 2020


Author: Frederik Gossen
Date: 2020-07-16T13:53:37Z
New Revision: 0eb50e614c65d189a3f1bdf747be973829046bc1

URL: https://github.com/llvm/llvm-project/commit/0eb50e614c65d189a3f1bdf747be973829046bc1
DIFF: https://github.com/llvm/llvm-project/commit/0eb50e614c65d189a3f1bdf747be973829046bc1.diff

LOG: [MLIR][Shape] Allow `shape.reduce` to operate on extent tensors

Allow `shape.reduce` to take both `shape.shape` and `tensor<?xindex>` as an
argument.

Differential Revision: https://reviews.llvm.org/D83943

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
    mlir/test/Dialect/Shape/invalid.mlir
    mlir/test/Dialect/Shape/ops.mlir
    mlir/test/Dialect/Shape/shape-to-shape.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 1f141a2e705a..090b4c6f4abb 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -338,23 +338,26 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
 
 def Shape_ReduceOp : Shape_Op<"reduce",
     [SingleBlockImplicitTerminator<"YieldOp">]> {
-  let summary = "Returns an expression reduced over a shape";
+  let summary = "Returns an expression reduced over a shape or extent tensor";
   let description = [{
-    An operation that takes as input a shape, number of initial values and has a
-    region/function that is applied repeatedly for every dimension of the shape.
+    An operation that takes as input a shape or extent tensor, and a number of
+    initial values. This operation has a region/function that is applied
+    repeatedly for every extent of the input. Starting with the initial values,
+    the individual extents are then aggregated as defined by the associated
+    region.
 
     Conceptually this op performs the following reduction:
 
     ```
     res[] = init;
-    for (int i = 0, e = shape.rank(); i != e; ++i) {
+    for (int i = 0; i < shape.rank(); i++) {
       res = fn(i, shape[i], res[0], ..., res[n]);
     }
     ```
 
-    Where fn is provided by the user and the result of the reduce op is the
+    Where `fn` is provided by the user and the result of the reduce op is the
     last computed output of the reduce function. As an example, computing the
-    number of elements
+    number of elements can be defined as follows:
 
     ```mlir
     func @reduce(%shape : !shape.shape, %init : !shape.size) -> !shape.size {
@@ -367,11 +370,10 @@ def Shape_ReduceOp : Shape_Op<"reduce",
       return %num_elements : !shape.size
     }
     ```
-
-    If the shape is unranked, then the results of the op is also unranked.
   }];
 
-  let arguments = (ins Shape_ShapeType:$shape, Variadic<AnyType>:$initVals);
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
+                       Variadic<AnyType>:$initVals);
   let results = (outs Variadic<AnyType>:$result);
   let regions = (region SizedRegion<1>:$region);
 

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index a6f54053a326..b983968b124d 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -721,18 +721,31 @@ static LogicalResult verify(ReduceOp op) {
   // Verify block arg types.
   Block &block = op.region().front();
 
+  // The block takes index, extent, and aggregated values as arguments.
   auto blockArgsCount = op.initVals().size() + 2;
   if (block.getNumArguments() != blockArgsCount)
     return op.emitOpError() << "ReduceOp body is expected to have "
                             << blockArgsCount << " arguments";
 
-  if (block.getArgument(0).getType() != IndexType::get(op.getContext()))
+  // The first block argument is the index and must always be of type `index`.
+  if (!block.getArgument(0).getType().isa<IndexType>())
     return op.emitOpError(
         "argument 0 of ReduceOp body is expected to be of IndexType");
 
-  if (block.getArgument(1).getType() != SizeType::get(op.getContext()))
-    return op.emitOpError(
-        "argument 1 of ReduceOp body is expected to be of SizeType");
+  // The second block argument is the extent and must be of type `size` or
+  // `index`, depending on whether the reduce operation is applied to a shape or
+  // to an extent tensor.
+  Type extentTy = block.getArgument(1).getType();
+  if (op.shape().getType().isa<ShapeType>()) {
+    if (!extentTy.isa<SizeType>())
+      return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
+                            "SizeType if the ReduceOp operates on a ShapeType");
+  } else {
+    if (!extentTy.isa<IndexType>())
+      return op.emitOpError(
+          "argument 1 of ReduceOp body is expected to be of IndexType if the "
+          "ReduceOp operates on an extent tensor");
+  }
 
   for (auto type : llvm::enumerate(op.initVals()))
     if (block.getArgument(type.index() + 2).getType() != type.value().getType())
@@ -743,17 +756,18 @@ static LogicalResult verify(ReduceOp op) {
 }
 
 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
-  auto *ctx = parser.getBuilder().getContext();
   // Parse operands.
   SmallVector<OpAsmParser::OperandType, 3> operands;
+  Type shapeOrExtentTensorType;
   if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
                               OpAsmParser::Delimiter::Paren) ||
+      parser.parseColonType(shapeOrExtentTensorType) ||
       parser.parseOptionalArrowTypeList(result.types))
     return failure();
 
   // Resolve operands.
   auto initVals = llvm::makeArrayRef(operands).drop_front();
-  if (parser.resolveOperand(operands.front(), ShapeType::get(ctx),
+  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
                             result.operands) ||
       parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
                              result.operands))
@@ -773,7 +787,7 @@ static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
 
 static void print(OpAsmPrinter &p, ReduceOp op) {
   p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
-    << ") ";
+    << ") : " << op.shape().getType();
   p.printOptionalArrowTypeList(op.getResultTypes());
   p.printRegion(op.region());
   p.printOptionalAttrDict(op.getAttrs());

diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
index 1c214567c63a..9051054b3f18 100644
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -1,10 +1,10 @@
 // RUN: mlir-opt -convert-shape-to-scf -split-input-file %s | FileCheck %s
 
-// CHECK-LABEL: shape_reduce
-// CHECK-SAME:   [[SHAPE:%.*]]: !shape.shape) -> !shape.size {
+// CHECK-LABEL: @shape_reduce
+// CHECK-SAME:  ([[SHAPE:%.*]]: !shape.shape) -> !shape.size
 func @shape_reduce(%shape : !shape.shape) -> !shape.size {
   %init = shape.const_size 1
-  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+  %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
     ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
       %new_acc = shape.mul %acc, %dim
       shape.yield %new_acc : !shape.size

diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index da059a489be3..3aca3677c143 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -2,7 +2,7 @@
 
 func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
   // expected-error @+1 {{ReduceOp body is expected to have 3 arguments}}
-  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+  %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
     ^bb0(%index: index, %dim: !shape.size):
       shape.yield %dim : !shape.size
   }
@@ -12,7 +12,7 @@ func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
 
 func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {
   // expected-error @+1 {{argument 0 of ReduceOp body is expected to be of IndexType}}
-  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+  %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
     ^bb0(%index: f32, %dim: !shape.size, %acc: !shape.size):
       %new_acc = "shape.add"(%acc, %dim)
           : (!shape.size, !shape.size) -> !shape.size
@@ -23,8 +23,8 @@ func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {
 // -----
 
 func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) {
-  // expected-error @+1 {{argument 1 of ReduceOp body is expected to be of SizeType}}
-  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+  // expected-error @+1 {{argument 1 of ReduceOp body is expected to be of SizeType if the ReduceOp operates on a ShapeType}}
+  %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
     ^bb0(%index: index, %dim: f32, %lci: !shape.size):
       shape.yield
   }
@@ -32,9 +32,19 @@ func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) {
 
 // -----
 
+func @reduce_op_arg1_wrong_type(%shape : tensor<?xindex>, %init : index) {
+  // expected-error @+1 {{argument 1 of ReduceOp body is expected to be of IndexType if the ReduceOp operates on an extent tensor}}
+  %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
+    ^bb0(%index: index, %dim: f32, %lci: index):
+      shape.yield
+  }
+}
+
+// -----
+
 func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) {
   // expected-error @+1 {{type mismatch between argument 2 of ReduceOp body and initial value 0}}
-  %num_elements = shape.reduce(%shape, %init) -> f32 {
+  %num_elements = shape.reduce(%shape, %init) : !shape.shape -> f32 {
     ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
       shape.yield
   }
@@ -44,7 +54,7 @@ func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) {
 
 func @yield_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
   // expected-error @+3 {{number of operands does not match number of results of its parent}}
-  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+  %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
     ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
       shape.yield %dim, %dim : !shape.size, !shape.size
   }
@@ -54,7 +64,7 @@ func @yield_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
 
 func @yield_op_type_mismatch(%shape : !shape.shape, %init : !shape.size) {
   // expected-error @+4 {{types mismatch between yield op and its parent}}
-  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+  %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
     ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
       %c0 = constant 1 : index
       shape.yield %c0 : index

diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index 3a0bcf713073..c6f52519ad2e 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -6,15 +6,26 @@
 
 // CHECK-LABEL: shape_num_elements
 func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
-  %init = shape.const_size 0
-  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
-    ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
-      %acc = shape.add %lci, %dim
-      shape.yield %acc : !shape.size
+  %init = shape.const_size 1
+  %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
+    ^bb0(%index : index, %extent : !shape.size, %acc : !shape.size):
+      %acc_next = shape.mul %acc, %extent
+      shape.yield %acc_next : !shape.size
   }
   return %num_elements : !shape.size
 }
 
+// CHECK-LABEL: extent_tensor_num_elements
+func @extent_tensor_num_elements(%shape : tensor<?xindex>) -> index {
+  %init = constant 1 : index
+  %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
+    ^bb0(%index : index, %extent : index, %acc : index):
+      %acc_next = muli %acc, %extent : index
+      shape.yield %acc_next : index
+  }
+  return %num_elements : index
+}
+
 func @test_shape_num_elements_unknown() {
   %0 = "shape.unknown_shape"() : () -> !shape.shape
   %1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)

diff --git a/mlir/test/Dialect/Shape/shape-to-shape.mlir b/mlir/test/Dialect/Shape/shape-to-shape.mlir
index b3be4c9de3a1..9a75f0b9ca1b 100644
--- a/mlir/test/Dialect/Shape/shape-to-shape.mlir
+++ b/mlir/test/Dialect/Shape/shape-to-shape.mlir
@@ -1,16 +1,16 @@
 // RUN: mlir-opt -shape-to-shape-lowering -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: func @num_elements_to_reduce(
-// CHECK-SAME:    [[ARG:%.*]]: !shape.shape) -> [[SIZE_TY:!.*]] {
+// CHECK-SAME:    [[ARG:%.*]]: !shape.shape) -> !shape.size {
 func @num_elements_to_reduce(%shape : !shape.shape) -> !shape.size {
   %num_elements = shape.num_elements %shape
   return %num_elements : !shape.size
 }
 // CHECK: [[C1:%.*]] = shape.const_size 1
-// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]])  -> [[SIZE_TY]]
-// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: [[SIZE_TY]], [[ACC:%.*]]: [[SIZE_TY]]
+// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]]) : !shape.shape -> !shape.size
+// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: !shape.size, [[ACC:%.*]]: !shape.size
 // CHECK:   [[NEW_ACC:%.*]] = shape.mul [[DIM]], [[ACC]]
-// CHECK:   shape.yield [[NEW_ACC]] : [[SIZE_TY]]
+// CHECK:   shape.yield [[NEW_ACC]] : !shape.size
 // CHECK: }
-// CHECK: return [[NUM_ELEMENTS]] : [[SIZE_TY]]
+// CHECK: return [[NUM_ELEMENTS]] : !shape.size
 


        


More information about the Mlir-commits mailing list