[Mlir-commits] [mlir] 3147342 - [MLIR] Change custom printer/parser for loop.parallel and loop.reduce.

Alexander Belyaev llvmlistbot at llvm.org
Mon Mar 9 07:12:32 PDT 2020


Author: Alexander Belyaev
Date: 2020-03-09T15:11:48+01:00
New Revision: 3147342ae7ef9470f879fd62bac6b0786a4f0d65

URL: https://github.com/llvm/llvm-project/commit/3147342ae7ef9470f879fd62bac6b0786a4f0d65
DIFF: https://github.com/llvm/llvm-project/commit/3147342ae7ef9470f879fd62bac6b0786a4f0d65.diff

LOG: [MLIR] Change custom printer/parser for loop.parallel and loop.reduce.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LoopOps/LoopOps.td
    mlir/lib/Dialect/LoopOps/LoopOps.cpp
    mlir/test/Conversion/convert-to-cfg.mlir
    mlir/test/Dialect/Loops/invalid.mlir
    mlir/test/Dialect/Loops/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
index 8850349af574..28b2e8c99392 100644
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -272,13 +272,13 @@ def ParallelOp : Loop_Op<"parallel",
     For example:
 
     ```mlir
-       loop.parallel (%iv) = (%lb) to (%ub) step (%step) {
+       loop.parallel (%iv) = (%lb) to (%ub) step (%step) -> f32 {
          %zero = constant 0.0 : f32
-         loop.reduce(%zero) {
+         loop.reduce(%zero) : f32 {
            ^bb0(%lhs : f32, %rhs: f32):
              %res = addf %lhs, %rhs : f32
              loop.reduce.return %res : f32
-         } : f32
+         }
        }
     ```
    }];

diff  --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
index c0cb149bf815..9c28eec27eba 100644
--- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp
+++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -407,7 +407,7 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
       parser.resolveOperands(upper, builder.getIndexType(), result.operands))
     return failure();
 
-  // Parse step value.
+  // Parse step values.
   SmallVector<OpAsmParser::OperandType, 4> steps;
   if (parser.parseKeyword("step") ||
       parser.parseOperandList(steps, ivs.size(),
@@ -415,7 +415,7 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
       parser.resolveOperands(steps, builder.getIndexType(), result.operands))
     return failure();
 
-  // Parse step value.
+  // Parse init values.
   SmallVector<OpAsmParser::OperandType, 4> initVals;
   if (succeeded(parser.parseOptionalKeyword("init"))) {
     if (parser.parseOperandList(initVals, /*requiredOperandCount=*/-1,
@@ -423,6 +423,10 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
       return failure();
   }
 
+  // Parse optional results in case there is a reduce.
+  if (parser.parseOptionalArrowTypeList(result.types))
+    return failure();
+
   // Now parse the body.
   Region *body = result.addRegion();
   SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
@@ -437,9 +441,8 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
                                 static_cast<int32_t>(steps.size()),
                                 static_cast<int32_t>(initVals.size())}));
 
-  // Parse attributes and optional results (in case there is a reduce).
-  if (parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseOptionalColonTypeList(result.types))
+  // Parse attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
 
   if (!initVals.empty())
@@ -457,11 +460,10 @@ static void print(OpAsmPrinter &p, ParallelOp op) {
     << ")";
   if (!op.initVals().empty())
     p << " init (" << op.initVals() << ")";
+  p.printOptionalArrowTypeList(op.getResultTypes());
   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
   p.printOptionalAttrDict(
       op.getAttrs(), /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
-  if (!op.results().empty())
-    p << " : " << op.getResultTypes();
 }
 
 ParallelOp mlir::loop::getParallelForInductionVarOwner(Value val) {
@@ -515,24 +517,24 @@ static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
       parser.parseRParen())
     return failure();
 
-  // Now parse the body.
-  Region *body = result.addRegion();
-  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
-    return failure();
-
-  // And the type of the operand (and also what reduce computes on).
   Type resultType;
+  // Parse the type of the operand (and also what reduce computes on).
   if (parser.parseColonType(resultType) ||
       parser.resolveOperand(operand, resultType, result.operands))
     return failure();
 
+  // Now parse the body.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+
   return success();
 }
 
 static void print(OpAsmPrinter &p, ReduceOp op) {
   p << op.getOperationName() << "(" << op.operand() << ") ";
-  p.printRegion(op.reductionOperator());
   p << " : " << op.operand().getType();
+  p.printRegion(op.reductionOperator());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/convert-to-cfg.mlir b/mlir/test/Conversion/convert-to-cfg.mlir
index 54c5d4c4a9cf..8a8a999d5ee9 100644
--- a/mlir/test/Conversion/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/convert-to-cfg.mlir
@@ -268,14 +268,14 @@ func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
   // The continuation block has access to the (last value of) reduction.
   // CHECK: ^[[CONTINUE]]:
   // CHECK:   return %[[ITER_ARG]]
-  %0 = loop.parallel (%i) = (%arg0) to (%arg1) step (%arg2) init(%arg3) {
+  %0 = loop.parallel (%i) = (%arg0) to (%arg1) step (%arg2) init(%arg3) -> f32 {
     %cst = constant 42.0 : f32
-    loop.reduce(%cst) {
+    loop.reduce(%cst) : f32 {
     ^bb0(%lhs: f32, %rhs: f32):
       %1 = mulf %lhs, %rhs : f32
       loop.reduce.return %1 : f32
-    } : f32
-  } : f32
+    }
+  }
   return %0 : f32
 }
 
@@ -304,20 +304,20 @@ func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index,
   %step = constant 1 : index
   %init = constant 42 : i64
   %0:2 = loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
-                       step (%arg4, %step) init(%arg5, %init) {
+                       step (%arg4, %step) init(%arg5, %init) -> (f32, i64) {
     %cf = constant 42.0 : f32
-    loop.reduce(%cf) {
+    loop.reduce(%cf) : f32 {
     ^bb0(%lhs: f32, %rhs: f32):
       %1 = addf %lhs, %rhs : f32
       loop.reduce.return %1 : f32
-    } : f32
+    }
 
     %2 = call @generate() : () -> i64
-    loop.reduce(%2) {
+    loop.reduce(%2) : i64 {
     ^bb0(%lhs: i64, %rhs: i64):
       %3 = or %lhs, %rhs : i64
       loop.reduce.return %3 : i64
-    } : i64
-  } : f32, i64
+    }
+  }
   return %0#0, %0#1 : f32, i64
 }

diff  --git a/mlir/test/Dialect/Loops/invalid.mlir b/mlir/test/Dialect/Loops/invalid.mlir
index 44075aca59af..6962387b946c 100644
--- a/mlir/test/Dialect/Loops/invalid.mlir
+++ b/mlir/test/Dialect/Loops/invalid.mlir
@@ -175,10 +175,10 @@ func @parallel_fewer_results_than_reduces(
   // expected-error at +1 {{expects number of results: 0 to be the same as number of reductions: 1}}
   loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
     %c0 = constant 1.0 : f32
-    loop.reduce(%c0) {
+    loop.reduce(%c0) : f32 {
       ^bb0(%lhs: f32, %rhs: f32):
         loop.reduce.return %lhs : f32
-    } : f32
+    }
   }
   return
 }
@@ -189,8 +189,8 @@ func @parallel_more_results_than_reduces(
     %arg0 : index, %arg1 : index, %arg2 : index) {
   // expected-error at +2 {{expects number of results: 1 to be the same as number of reductions: 0}}
   %zero = constant 1.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) init (%zero) {
-  } : f32
+  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) init (%zero) -> f32 {
+  }
 
   return
 }
@@ -200,13 +200,12 @@ func @parallel_more_results_than_reduces(
 func @parallel_more_results_than_initial_values(
     %arg0 : index, %arg1: index, %arg2: index) {
   // expected-error at +1 {{expects number of results: 1 to be the same as number of initial values: 0}}
-  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) {
-    loop.reduce(%arg0) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) -> f32 {
+    loop.reduce(%arg0) : index {
       ^bb0(%lhs: index, %rhs: index):
         loop.reduce.return %lhs : index
-    } : index
-  } : f32
-  return
+    }
+  }
 }
 
 // -----
@@ -214,13 +213,14 @@ func @parallel_more_results_than_initial_values(
 func @parallel_different_types_of_results_and_reduces(
     %arg0 : index, %arg1: index, %arg2: index) {
   %zero = constant 0.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) init (%zero) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg1)
+                                       step (%arg2) init (%zero) -> f32 {
     // expected-error at +1 {{expects type of reduce: 'index' to be the same as result type: 'f32'}}
-    loop.reduce(%arg0) {
+    loop.reduce(%arg0) : index {
       ^bb0(%lhs: index, %rhs: index):
         loop.reduce.return %lhs : index
-    } : index
-  } : f32
+    }
+  }
   return
 }
 
@@ -228,10 +228,10 @@ func @parallel_different_types_of_results_and_reduces(
 
 func @top_level_reduce(%arg0 : f32) {
   // expected-error at +1 {{expects parent op 'loop.parallel'}}
-  loop.reduce(%arg0) {
+  loop.reduce(%arg0) : f32 {
     ^bb0(%lhs : f32, %rhs : f32):
       loop.reduce.return %lhs : f32
-  } : f32
+  }
   return
 }
 
@@ -239,12 +239,13 @@ func @top_level_reduce(%arg0 : f32) {
 
 func @reduce_empty_block(%arg0 : index, %arg1 : f32) {
   %zero = constant 0.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0)
+                                       step (%arg0) init (%zero) -> f32 {
     // expected-error at +1 {{the block inside reduce should not be empty}}
-    loop.reduce(%arg1) {
+    loop.reduce(%arg1) : f32 {
       ^bb0(%lhs : f32, %rhs : f32):
-    } : f32
-  } : f32
+    }
+  }
   return
 }
 
@@ -252,13 +253,14 @@ func @reduce_empty_block(%arg0 : index, %arg1 : f32) {
 
 func @reduce_too_many_args(%arg0 : index, %arg1 : f32) {
   %zero = constant 0.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0)
+                                       step (%arg0) init (%zero) -> f32 {
     // expected-error at +1 {{expects two arguments to reduce block of type 'f32'}}
-    loop.reduce(%arg1) {
+    loop.reduce(%arg1) : f32 {
       ^bb0(%lhs : f32, %rhs : f32, %other : f32):
         loop.reduce.return %lhs : f32
-    } : f32
-  } : f32
+    }
+  }
   return
 }
 
@@ -266,13 +268,14 @@ func @reduce_too_many_args(%arg0 : index, %arg1 : f32) {
 
 func @reduce_wrong_args(%arg0 : index, %arg1 : f32) {
   %zero = constant 0.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0)
+                                       step (%arg0) init (%zero) -> f32 {
     // expected-error at +1 {{expects two arguments to reduce block of type 'f32'}}
-    loop.reduce(%arg1) {
+    loop.reduce(%arg1) : f32 {
       ^bb0(%lhs : f32, %rhs : i32):
         loop.reduce.return %lhs : f32
-    } : f32
-  } : f32
+    }
+  }
   return
 }
 
@@ -281,13 +284,14 @@ func @reduce_wrong_args(%arg0 : index, %arg1 : f32) {
 
 func @reduce_wrong_terminator(%arg0 : index, %arg1 : f32) {
   %zero = constant 0.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0)
+                                       step (%arg0) init (%zero) -> f32 {
     // expected-error at +1 {{the block inside reduce should be terminated with a 'loop.reduce.return' op}}
-    loop.reduce(%arg1) {
+    loop.reduce(%arg1) : f32 {
       ^bb0(%lhs : f32, %rhs : f32):
         loop.yield
-    } : f32
-  } : f32
+    }
+  }
   return
 }
 
@@ -295,14 +299,15 @@ func @reduce_wrong_terminator(%arg0 : index, %arg1 : f32) {
 
 func @reduceReturn_wrong_type(%arg0 : index, %arg1: f32) {
   %zero = constant 0.0 : f32
-  %res = loop.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) {
-    loop.reduce(%arg1) {
+  %res = loop.parallel (%i0) = (%arg0) to (%arg0)
+                                       step (%arg0) init (%zero) -> f32 {
+    loop.reduce(%arg1) : f32 {
       ^bb0(%lhs : f32, %rhs : f32):
         %c0 = constant 1 : index
         // expected-error at +1 {{needs to have type 'f32' (the type of the enclosing ReduceOp)}}
         loop.reduce.return %c0 : index
-    } : f32
-  } : f32
+    }
+  }
   return
 }
 
@@ -349,7 +354,8 @@ func @std_for_operands_mismatch(%arg0 : index, %arg1 : index, %arg2 : index) {
   %s0 = constant 0.0 : f32
   %t0 = constant 1 : i32
   // expected-error at +1 {{mismatch in number of loop-carried values and defined values}}
-  %result1:3 = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %ti = %t0) -> (f32, i32, f32) {
+  %result1:3 = loop.for %i0 = %arg0 to %arg1 step %arg2
+                    iter_args(%si = %s0, %ti = %t0) -> (f32, i32, f32) {
     %sn = addf %si, %si : f32
     %tn = addi %ti, %ti : i32
     loop.yield %sn, %tn, %sn : f32, i32, f32
@@ -364,7 +370,8 @@ func @std_for_operands_mismatch_2(%arg0 : index, %arg1 : index, %arg2 : index) {
   %t0 = constant 1 : i32
   %u0 = constant 1.0 : f32
   // expected-error at +1 {{mismatch in number of loop-carried values and defined values}}
-  %result1:2 = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %ti = %t0, %ui = %u0) -> (f32, i32) {
+  %result1:2 = loop.for %i0 = %arg0 to %arg1 step %arg2
+                    iter_args(%si = %s0, %ti = %t0, %ui = %u0) -> (f32, i32) {
     %sn = addf %si, %si : f32
     %tn = addi %ti, %ti : i32
     %un = subf %ui, %ui : f32
@@ -379,8 +386,9 @@ func @std_for_operands_mismatch_3(%arg0 : index, %arg1 : index, %arg2 : index) {
   // expected-note at +1 {{prior use here}}
   %s0 = constant 0.0 : f32
   %t0 = constant 1.0 : f32
-  // expected-error at +1 {{expects different type than prior uses: 'i32' vs 'f32'}}
-  %result1:2 = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %ti = %t0) -> (i32, i32) {
+  // expected-error at +2 {{expects different type than prior uses: 'i32' vs 'f32'}}
+  %result1:2 = loop.for %i0 = %arg0 to %arg1 step %arg2
+                    iter_args(%si = %s0, %ti = %t0) -> (i32, i32) {
     %sn = addf %si, %si : i32
     %tn = addf %ti, %ti : i32
     loop.yield %sn, %tn : i32, i32

diff  --git a/mlir/test/Dialect/Loops/ops.mlir b/mlir/test/Dialect/Loops/ops.mlir
index 40aef314d273..881feb46ead4 100644
--- a/mlir/test/Dialect/Loops/ops.mlir
+++ b/mlir/test/Dialect/Loops/ops.mlir
@@ -60,14 +60,22 @@ func @std_parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
     %max_cmp = cmpi "sge", %i0, %i1 : index
     %max = select %max_cmp, %i0, %i1 : index
     %zero = constant 0.0 : f32
-    %red = loop.parallel (%i2) = (%min) to (%max) step (%i1) init (%zero) {
+    %int_zero = constant 0 : i32
+    %red:2 = loop.parallel (%i2) = (%min) to (%max) step (%i1)
+                                      init (%zero, %int_zero) -> (f32, i32) {
       %one = constant 1.0 : f32
-      loop.reduce(%one) {
+      loop.reduce(%one) : f32 {
         ^bb0(%lhs : f32, %rhs: f32):
           %res = addf %lhs, %rhs : f32
           loop.reduce.return %res : f32
-      } : f32
-    } : f32
+      }
+      %int_one = constant 1 : i32
+      loop.reduce(%int_one) : i32 {
+        ^bb0(%lhs : i32, %rhs: i32):
+          %res = muli %lhs, %rhs : i32
+          loop.reduce.return %res : i32
+      }
+    }
   }
   return
 }
@@ -85,16 +93,24 @@ func @std_parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
 //  CHECK-NEXT:     %[[MAX_CMP:.*]] = cmpi "sge", %[[I0]], %[[I1]] : index
 //  CHECK-NEXT:     %[[MAX:.*]] = select %[[MAX_CMP]], %[[I0]], %[[I1]] : index
 //  CHECK-NEXT:     %[[ZERO:.*]] = constant 0.000000e+00 : f32
+//  CHECK-NEXT:     %[[INT_ZERO:.*]] = constant 0 : i32
 //  CHECK-NEXT:     loop.parallel (%{{.*}}) = (%[[MIN]]) to (%[[MAX]])
-//  CHECK-SAME:           step (%[[I1]]) init (%[[ZERO]]) {
+//  CHECK-SAME:          step (%[[I1]])
+//  CHECK-SAME:          init (%[[ZERO]], %[[INT_ZERO]]) -> (f32, i32) {
 //  CHECK-NEXT:       %[[ONE:.*]] = constant 1.000000e+00 : f32
-//  CHECK-NEXT:       loop.reduce(%[[ONE]]) {
+//  CHECK-NEXT:       loop.reduce(%[[ONE]]) : f32 {
 //  CHECK-NEXT:       ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
 //  CHECK-NEXT:         %[[RES:.*]] = addf %[[LHS]], %[[RHS]] : f32
 //  CHECK-NEXT:         loop.reduce.return %[[RES]] : f32
-//  CHECK-NEXT:       } : f32
+//  CHECK-NEXT:       }
+//  CHECK-NEXT:       %[[INT_ONE:.*]] = constant 1 : i32
+//  CHECK-NEXT:       loop.reduce(%[[INT_ONE]]) : i32 {
+//  CHECK-NEXT:       ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
+//  CHECK-NEXT:         %[[RES:.*]] = muli %[[LHS]], %[[RHS]] : i32
+//  CHECK-NEXT:         loop.reduce.return %[[RES]] : i32
+//  CHECK-NEXT:       }
 //  CHECK-NEXT:       loop.yield
-//  CHECK-NEXT:     } : f32
+//  CHECK-NEXT:     }
 //  CHECK-NEXT:     loop.yield
 
 func @parallel_explicit_yield(


        


More information about the Mlir-commits mailing list