[Mlir-commits] [mlir] 3147342 - [MLIR] Change custom printer/parser for loop.parallel and loop.reduce.
Alexander Belyaev
llvmlistbot at llvm.org
Mon Mar 9 07:12:32 PDT 2020
Author: Alexander Belyaev
Date: 2020-03-09T15:11:48+01:00
New Revision: 3147342ae7ef9470f879fd62bac6b0786a4f0d65
URL: https://github.com/llvm/llvm-project/commit/3147342ae7ef9470f879fd62bac6b0786a4f0d65
DIFF: https://github.com/llvm/llvm-project/commit/3147342ae7ef9470f879fd62bac6b0786a4f0d65.diff
LOG: [MLIR] Change custom printer/parser for loop.parallel and loop.reduce.
Added:
Modified:
mlir/include/mlir/Dialect/LoopOps/LoopOps.td
mlir/lib/Dialect/LoopOps/LoopOps.cpp
mlir/test/Conversion/convert-to-cfg.mlir
mlir/test/Dialect/Loops/invalid.mlir
mlir/test/Dialect/Loops/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
index 8850349af574..28b2e8c99392 100644
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -272,13 +272,13 @@ def ParallelOp : Loop_Op<"parallel",
For example:
```mlir
- loop.parallel (%iv) = (%lb) to (%ub) step (%step) {
+ loop.parallel (%iv) = (%lb) to (%ub) step (%step) -> f32 {
%zero = constant 0.0 : f32
- loop.reduce(%zero) {
+ loop.reduce(%zero) : f32 {
^bb0(%lhs : f32, %rhs: f32):
%res = addf %lhs, %rhs : f32
loop.reduce.return %res : f32
- } : f32
+ }
}
```
}];
diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
index c0cb149bf815..9c28eec27eba 100644
--- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp
+++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -407,7 +407,7 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
parser.resolveOperands(upper, builder.getIndexType(), result.operands))
return failure();
- // Parse step value.
+ // Parse step values.
SmallVector<OpAsmParser::OperandType, 4> steps;
if (parser.parseKeyword("step") ||
parser.parseOperandList(steps, ivs.size(),
@@ -415,7 +415,7 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
parser.resolveOperands(steps, builder.getIndexType(), result.operands))
return failure();
- // Parse step value.
+ // Parse init values.
SmallVector<OpAsmParser::OperandType, 4> initVals;
if (succeeded(parser.parseOptionalKeyword("init"))) {
if (parser.parseOperandList(initVals, /*requiredOperandCount=*/-1,
@@ -423,6 +423,10 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
return failure();
}
+ // Parse optional results in case there is a reduce.
+ if (parser.parseOptionalArrowTypeList(result.types))
+ return failure();
+
// Now parse the body.
Region *body = result.addRegion();
SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
@@ -437,9 +441,8 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
static_cast<int32_t>(steps.size()),
static_cast<int32_t>(initVals.size())}));
- // Parse attributes and optional results (in case there is a reduce).
- if (parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseOptionalColonTypeList(result.types))
+ // Parse attributes.
+ if (parser.parseOptionalAttrDict(result.attributes))
return failure();
if (!initVals.empty())
@@ -457,11 +460,10 @@ static void print(OpAsmPrinter &p, ParallelOp op) {
<< ")";
if (!op.initVals().empty())
p << " init (" << op.initVals() << ")";
+ p.printOptionalArrowTypeList(op.getResultTypes());
p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict(
op.getAttrs(), /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
- if (!op.results().empty())
- p << " : " << op.getResultTypes();
}
ParallelOp mlir::loop::getParallelForInductionVarOwner(Value val) {
@@ -515,24 +517,24 @@ static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
parser.parseRParen())
return failure();
- // Now parse the body.
- Region *body = result.addRegion();
- if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
- return failure();
-
- // And the type of the operand (and also what reduce computes on).
Type resultType;
+ // Parse the type of the operand (and also what reduce computes on).
if (parser.parseColonType(resultType) ||
parser.resolveOperand(operand, resultType, result.operands))
return failure();
+ // Now parse the body.
+ Region *body = result.addRegion();
+ if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
+ return failure();
+
return success();
}
static void print(OpAsmPrinter &p, ReduceOp op) {
p << op.getOperationName() << "(" << op.operand() << ") ";
- p.printRegion(op.reductionOperator());
p << " : " << op.operand().getType();
+ p.printRegion(op.reductionOperator());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/convert-to-cfg.mlir b/mlir/test/Conversion/convert-to-cfg.mlir
index 54c5d4c4a9cf..8a8a999d5ee9 100644
--- a/mlir/test/Conversion/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/convert-to-cfg.mlir
@@ -268,14 +268,14 @@ func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
// The continuation block has access to the (last value of) reduction.
// CHECK: ^[[CONTINUE]]:
// CHECK: return %[[ITER_ARG]]
- %0 = loop.parallel (%i) = (%arg0) to (%arg1) step (%arg2) init(%arg3) {
+ %0 = loop.parallel (%i) = (%arg0) to (%arg1) step (%arg2) init(%arg3) -> f32 {
%cst = constant 42.0 : f32
- loop.reduce(%cst) {
+ loop.reduce(%cst) : f32 {
^bb0(%lhs: f32, %rhs: f32):
%1 = mulf %lhs, %rhs : f32
loop.reduce.return %1 : f32
- } : f32
- } : f32
+ }
+ }
return %0 : f32
}
@@ -304,20 +304,20 @@ func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index,
%step = constant 1 : index
%init = constant 42 : i64
%0:2 = loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
- step (%arg4, %step) init(%arg5, %init) {
+ step (%arg4, %step) init(%arg5, %init) -> (f32, i64) {
%cf = constant 42.0 : f32
- loop.reduce(%cf) {
+ loop.reduce(%cf) : f32 {
^bb0(%lhs: f32, %rhs: f32):
%1 = addf %lhs, %rhs : f32
loop.reduce.return %1 : f32
- } : f32
+ }
%2 = call @generate() : () -> i64
- loop.reduce(%2) {
+ loop.reduce(%2) : i64 {
^bb0(%lhs: i64, %rhs: i64):
%3 = or %lhs, %rhs : i64
loop.reduce.return %3 : i64
- } : i64
- } : f32, i64
+ }
+ }
return %0#0, %0#1 : f32, i64
}
diff --git a/mlir/test/Dialect/Loops/invalid.mlir b/mlir/test/Dialect/Loops/invalid.mlir
index 44075aca59af..6962387b946c 100644
--- a/mlir/test/Dialect/Loops/invalid.mlir
+++ b/mlir/test/Dialect/Loops/invalid.mlir
@@ -175,10 +175,10 @@ func @parallel_fewer_results_than_reduces(
// expected-error at +1 {{expects number of results: 0 to be the same as number of reductions: 1}}
loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
%c0 = constant 1.0 : f32
- loop.reduce(%c0) {
+ loop.reduce(%c0) : f32 {
^bb0(%lhs: f32, %rhs: f32):
loop.reduce.return %lhs : f32
- } : f32
+ }
}
return
}
@@ -189,8 +189,8 @@ func @parallel_more_results_than_reduces(
%arg0 : index, %arg1 : index, %arg2 : index) {
// expected-error at +2 {{expects number of results: 1 to be the same as number of reductions: 0}}
%zero = constant 1.0 : f32
- %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) init (%zero) {
- } : f32
+ %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) init (%zero) -> f32 {
+ }
return
}
@@ -200,13 +200,12 @@ func @parallel_more_results_than_reduces(
func @parallel_more_results_than_initial_values(
%arg0 : index, %arg1: index, %arg2: index) {
// expected-error at +1 {{expects number of results: 1 to be the same as number of initial values: 0}}
- %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
- loop.reduce(%arg0) {
+ %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) -> f32 {
+ loop.reduce(%arg0) : index {
^bb0(%lhs: index, %rhs: index):
loop.reduce.return %lhs : index
- } : index
- } : f32
- return
+ }
+ }
}
// -----
@@ -214,13 +213,14 @@ func @parallel_more_results_than_initial_values(
func @parallel_different_types_of_results_and_reduces(
%arg0 : index, %arg1: index, %arg2: index) {
%zero = constant 0.0 : f32
- %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) init (%zero) {
+ %res = loop.parallel (%i0) = (%arg0) to (%arg1)
+ step (%arg2) init (%zero) -> f32 {
// expected-error at +1 {{expects type of reduce: 'index' to be the same as result type: 'f32'}}
- loop.reduce(%arg0) {
+ loop.reduce(%arg0) : index {
^bb0(%lhs: index, %rhs: index):
loop.reduce.return %lhs : index
- } : index
- } : f32
+ }
+ }
return
}
@@ -228,10 +228,10 @@ func @parallel_different_types_of_results_and_reduces(
func @top_level_reduce(%arg0 : f32) {
// expected-error at +1 {{expects parent op 'loop.parallel'}}
- loop.reduce(%arg0) {
+ loop.reduce(%arg0) : f32 {
^bb0(%lhs : f32, %rhs : f32):
loop.reduce.return %lhs : f32
- } : f32
+ }
return
}
@@ -239,12 +239,13 @@ func @top_level_reduce(%arg0 : f32) {
func @reduce_empty_block(%arg0 : index, %arg1 : f32) {
%zero = constant 0.0 : f32
- %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
+ %res = loop.parallel (%i0) = (%arg0) to (%arg0)
+ step (%arg0) init (%zero) -> f32 {
// expected-error at +1 {{the block inside reduce should not be empty}}
- loop.reduce(%arg1) {
+ loop.reduce(%arg1) : f32 {
^bb0(%lhs : f32, %rhs : f32):
- } : f32
- } : f32
+ }
+ }
return
}
@@ -252,13 +253,14 @@ func @reduce_empty_block(%arg0 : index, %arg1 : f32) {
func @reduce_too_many_args(%arg0 : index, %arg1 : f32) {
%zero = constant 0.0 : f32
- %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
+ %res = loop.parallel (%i0) = (%arg0) to (%arg0)
+ step (%arg0) init (%zero) -> f32 {
// expected-error at +1 {{expects two arguments to reduce block of type 'f32'}}
- loop.reduce(%arg1) {
+ loop.reduce(%arg1) : f32 {
^bb0(%lhs : f32, %rhs : f32, %other : f32):
loop.reduce.return %lhs : f32
- } : f32
- } : f32
+ }
+ }
return
}
@@ -266,13 +268,14 @@ func @reduce_too_many_args(%arg0 : index, %arg1 : f32) {
func @reduce_wrong_args(%arg0 : index, %arg1 : f32) {
%zero = constant 0.0 : f32
- %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
+ %res = loop.parallel (%i0) = (%arg0) to (%arg0)
+ step (%arg0) init (%zero) -> f32 {
// expected-error at +1 {{expects two arguments to reduce block of type 'f32'}}
- loop.reduce(%arg1) {
+ loop.reduce(%arg1) : f32 {
^bb0(%lhs : f32, %rhs : i32):
loop.reduce.return %lhs : f32
- } : f32
- } : f32
+ }
+ }
return
}
@@ -281,13 +284,14 @@ func @reduce_wrong_args(%arg0 : index, %arg1 : f32) {
func @reduce_wrong_terminator(%arg0 : index, %arg1 : f32) {
%zero = constant 0.0 : f32
- %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
+ %res = loop.parallel (%i0) = (%arg0) to (%arg0)
+ step (%arg0) init (%zero) -> f32 {
// expected-error at +1 {{the block inside reduce should be terminated with a 'loop.reduce.return' op}}
- loop.reduce(%arg1) {
+ loop.reduce(%arg1) : f32 {
^bb0(%lhs : f32, %rhs : f32):
loop.yield
- } : f32
- } : f32
+ }
+ }
return
}
@@ -295,14 +299,15 @@ func @reduce_wrong_terminator(%arg0 : index, %arg1 : f32) {
func @reduceReturn_wrong_type(%arg0 : index, %arg1: f32) {
%zero = constant 0.0 : f32
- %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
- loop.reduce(%arg1) {
+ %res = loop.parallel (%i0) = (%arg0) to (%arg0)
+ step (%arg0) init (%zero) -> f32 {
+ loop.reduce(%arg1) : f32 {
^bb0(%lhs : f32, %rhs : f32):
%c0 = constant 1 : index
// expected-error at +1 {{needs to have type 'f32' (the type of the enclosing ReduceOp)}}
loop.reduce.return %c0 : index
- } : f32
- } : f32
+ }
+ }
return
}
@@ -349,7 +354,8 @@ func @std_for_operands_mismatch(%arg0 : index, %arg1 : index, %arg2 : index) {
%s0 = constant 0.0 : f32
%t0 = constant 1 : i32
// expected-error at +1 {{mismatch in number of loop-carried values and defined values}}
- %result1:3 = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %ti = %t0) -> (f32, i32, f32) {
+ %result1:3 = loop.for %i0 = %arg0 to %arg1 step %arg2
+ iter_args(%si = %s0, %ti = %t0) -> (f32, i32, f32) {
%sn = addf %si, %si : f32
%tn = addi %ti, %ti : i32
loop.yield %sn, %tn, %sn : f32, i32, f32
@@ -364,7 +370,8 @@ func @std_for_operands_mismatch_2(%arg0 : index, %arg1 : index, %arg2 : index) {
%t0 = constant 1 : i32
%u0 = constant 1.0 : f32
// expected-error at +1 {{mismatch in number of loop-carried values and defined values}}
- %result1:2 = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %ti = %t0, %ui = %u0) -> (f32, i32) {
+ %result1:2 = loop.for %i0 = %arg0 to %arg1 step %arg2
+ iter_args(%si = %s0, %ti = %t0, %ui = %u0) -> (f32, i32) {
%sn = addf %si, %si : f32
%tn = addi %ti, %ti : i32
%un = subf %ui, %ui : f32
@@ -379,8 +386,9 @@ func @std_for_operands_mismatch_3(%arg0 : index, %arg1 : index, %arg2 : index) {
// expected-note at +1 {{prior use here}}
%s0 = constant 0.0 : f32
%t0 = constant 1.0 : f32
- // expected-error at +1 {{expects different type than prior uses: 'i32' vs 'f32'}}
- %result1:2 = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %ti = %t0) -> (i32, i32) {
+ // expected-error at +2 {{expects different type than prior uses: 'i32' vs 'f32'}}
+ %result1:2 = loop.for %i0 = %arg0 to %arg1 step %arg2
+ iter_args(%si = %s0, %ti = %t0) -> (i32, i32) {
%sn = addf %si, %si : i32
%tn = addf %ti, %ti : i32
loop.yield %sn, %tn : i32, i32
diff --git a/mlir/test/Dialect/Loops/ops.mlir b/mlir/test/Dialect/Loops/ops.mlir
index 40aef314d273..881feb46ead4 100644
--- a/mlir/test/Dialect/Loops/ops.mlir
+++ b/mlir/test/Dialect/Loops/ops.mlir
@@ -60,14 +60,22 @@ func @std_parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
%max_cmp = cmpi "sge", %i0, %i1 : index
%max = select %max_cmp, %i0, %i1 : index
%zero = constant 0.0 : f32
- %red = loop.parallel (%i2) = (%min) to (%max) step (%i1) init (%zero) {
+ %int_zero = constant 0 : i32
+ %red:2 = loop.parallel (%i2) = (%min) to (%max) step (%i1)
+ init (%zero, %int_zero) -> (f32, i32) {
%one = constant 1.0 : f32
- loop.reduce(%one) {
+ loop.reduce(%one) : f32 {
^bb0(%lhs : f32, %rhs: f32):
%res = addf %lhs, %rhs : f32
loop.reduce.return %res : f32
- } : f32
- } : f32
+ }
+ %int_one = constant 1 : i32
+ loop.reduce(%int_one) : i32 {
+ ^bb0(%lhs : i32, %rhs: i32):
+ %res = muli %lhs, %rhs : i32
+ loop.reduce.return %res : i32
+ }
+ }
}
return
}
@@ -85,16 +93,24 @@ func @std_parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
// CHECK-NEXT: %[[MAX_CMP:.*]] = cmpi "sge", %[[I0]], %[[I1]] : index
// CHECK-NEXT: %[[MAX:.*]] = select %[[MAX_CMP]], %[[I0]], %[[I1]] : index
// CHECK-NEXT: %[[ZERO:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[INT_ZERO:.*]] = constant 0 : i32
// CHECK-NEXT: loop.parallel (%{{.*}}) = (%[[MIN]]) to (%[[MAX]])
-// CHECK-SAME: step (%[[I1]]) init (%[[ZERO]]) {
+// CHECK-SAME: step (%[[I1]])
+// CHECK-SAME: init (%[[ZERO]], %[[INT_ZERO]]) -> (f32, i32) {
// CHECK-NEXT: %[[ONE:.*]] = constant 1.000000e+00 : f32
-// CHECK-NEXT: loop.reduce(%[[ONE]]) {
+// CHECK-NEXT: loop.reduce(%[[ONE]]) : f32 {
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK-NEXT: %[[RES:.*]] = addf %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: loop.reduce.return %[[RES]] : f32
-// CHECK-NEXT: } : f32
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[INT_ONE:.*]] = constant 1 : i32
+// CHECK-NEXT: loop.reduce(%[[INT_ONE]]) : i32 {
+// CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
+// CHECK-NEXT: %[[RES:.*]] = muli %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: loop.reduce.return %[[RES]] : i32
+// CHECK-NEXT: }
// CHECK-NEXT: loop.yield
-// CHECK-NEXT: } : f32
+// CHECK-NEXT: }
// CHECK-NEXT: loop.yield
func @parallel_explicit_yield(
More information about the Mlir-commits
mailing list