[Mlir-commits] [mlir] bbda411 - [mlir][scf] Relax type requirement on for

Jacques Pienaar llvmlistbot at llvm.org
Sun Mar 5 07:38:05 PST 2023


Author: Jacques Pienaar
Date: 2023-03-05T07:37:58-08:00
New Revision: bbda411f0e395a1727e243b62b331004cc8d0c30

URL: https://github.com/llvm/llvm-project/commit/bbda411f0e395a1727e243b62b331004cc8d0c30
DIFF: https://github.com/llvm/llvm-project/commit/bbda411f0e395a1727e243b62b331004cc8d0c30.diff

LOG: [mlir][scf] Relax type requirement on for

scf.for has been restricted to operate only on the index type since it
was split out from affine.for. Relax this requirement to also allow
signless integer types. This allows explicitly specifying different
bitwidths for different loops, as well as specializing from index to iN
while still using scf.for.

Differential Revision: https://reviews.llvm.org/D145288
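
A rough model of the operand rules this patch introduces may help when reading the verifier and test changes in the diff. The C++ sketch below does not use MLIR at all: `Kind`, `SimpleType`, `isSignlessIntegerOrIndex` and `verifyForHeader` are hypothetical names invented for illustration. It assumes the checks apply in this order: the per-operand `AnySignlessIntegerOrIndex` constraint, then `AllTypesMatch` on lower bound, upper bound and step, then the induction-variable check from `ForOp::verifyRegions`. The messages only echo the substrings expected by the updated invalid.mlir tests; the real diagnostics may carry extra prefixes.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for the MLIR types that matter here.
enum class Kind { Index, Integer, Float };

struct SimpleType {
  Kind kind;
  unsigned width = 0;      // bitwidth for Integer/Float, unused for Index
  bool isSignless = true;  // only meaningful for Integer (i32 vs si32/ui32)

  bool operator==(const SimpleType &o) const {
    return kind == o.kind && width == o.width && isSignless == o.isSignless;
  }
};

// Models the AnySignlessIntegerOrIndex operand constraint.
bool isSignlessIntegerOrIndex(const SimpleType &t) {
  return t.kind == Kind::Index || (t.kind == Kind::Integer && t.isSignless);
}

// Returns a diagnostic for the first violated rule, or std::nullopt if the
// bounds, step and induction variable are consistent.
std::optional<std::string> verifyForHeader(const SimpleType &lb,
                                           const SimpleType &ub,
                                           const SimpleType &step,
                                           const SimpleType &iv) {
  // 1. Each of lb/ub/step must be a signless integer or index.
  const SimpleType operands[] = {lb, ub, step};
  for (int i = 0; i < 3; ++i)
    if (!isSignlessIntegerOrIndex(operands[i]))
      return "operand #" + std::to_string(i) +
             " must be signless integer or index";
  // 2. AllTypesMatch<["lowerBound", "upperBound", "step"]>.
  if (!(lb == ub) || !(ub == step))
    return std::string("all of {lowerBound, upperBound, step} have same type");
  // 3. The induction variable (first block argument) must match the bounds.
  if (!(iv == lb))
    return std::string(
        "expected induction variable to be same type as bounds and step");
  return std::nullopt;
}
```

In this model `index` and a fixed-width integer such as `i64` are distinct types, so a loop mixing them is rejected by the second rule rather than silently accepted.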

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/test/Dialect/SCF/invalid.mlir
    mlir/test/Dialect/SCF/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 96ec62d27df6f..7d129caa3084a 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -121,50 +121,56 @@ def ForOp : SCF_Op<"for",
       [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
        ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
         "getSingleUpperBound"]>,
+       AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
        SingleBlockImplicitTerminator<"scf::YieldOp">,
        RecursiveMemoryEffects]> {
   let summary = "for operation";
   let description = [{
-    The "scf.for" operation represents a loop taking 3 SSA value as operands
+    The `scf.for` operation represents a loop taking 3 SSA value as operands
     that represent the lower bound, upper bound and step respectively. The
     operation defines an SSA value for its induction variable. It has one
     region capturing the loop body. The induction variable is represented as an
-    argument of this region. This SSA value always has type index, which is the
-    size of the machine word. The step is a value of type index, required to be
-    positive.
-    The lower and upper bounds specify a half-open range: the range includes
-    the lower bound but does not include the upper bound.
+    argument of this region. This SSA value is a signless integer or index.
+    The step is a value of same type but required to be positive. The lower and
+    upper bounds specify a half-open range: the range includes the lower bound
+    but does not include the upper bound.
 
     The body region must contain exactly one block that terminates with
-    "scf.yield". Calling ForOp::build will create such a region and insert
+    `scf.yield`. Calling ForOp::build will create such a region and insert
     the terminator implicitly if none is defined, so will the parsing even in
     cases when it is absent from the custom format. For example:
 
     ```mlir
+    // Index case.
     scf.for %iv = %lb to %ub step %step {
       ... // body
     }
+    ...
+    // Integer case.
+    scf.for %iv_32 = %lb_32 to %ub_32 step %step_32 : i32 {
+      ... // body
+    }
     ```
 
     `scf.for` can also operate on loop-carried variables and returns the final
     values after loop termination. The initial values of the variables are
-    passed as additional SSA operands to the "scf.for" following the 3 loop
+    passed as additional SSA operands to the `scf.for` following the 3 loop
     control SSA values mentioned above (lower bound, upper bound and step). The
     operation region has an argument for the induction variable, followed by
     one argument for each loop-carried variable, representing the value of the
     variable at the current iteration.
 
-    The region must terminate with a "scf.yield" that passes the current
+    The region must terminate with a `scf.yield` that passes the current
     values of all loop-carried variables to the next iteration, or to the
-    "scf.for" result, if at the last iteration. The static type of a
+    `scf.for` result, if at the last iteration. The static type of a
     loop-carried variable may not change with iterations; its runtime type is
     allowed to change. Note, that when the loop-carried variables are present,
     calling ForOp::build will not insert the terminator implicitly. The caller
-    must insert "scf.yield" in that case.
+    must insert `scf.yield` in that case.
 
-    "scf.for" results hold the final values after the last iteration.
+    `scf.for` results hold the final values after the last iteration.
     For example, to sum-reduce a memref:
 
     ```mlir
@@ -185,11 +191,11 @@ def ForOp : SCF_Op<"for",
     }
     ```
 
-    If the "scf.for" defines any values, a yield must be explicitly present.
-    The number and types of the "scf.for" results must match the initial
-    values in the "iter_args" binding and the yield operands.
+    If the `scf.for` defines any values, a yield must be explicitly present.
+    The number and types of the `scf.for` results must match the initial
+    values in the `iter_args` binding and the yield operands.
 
-    Another example with a nested "scf.if" (see "scf.if" for details) to
+    Another example with a nested `scf.if` (see `scf.if` for details) to
     perform conditional reduction:
 
     ```mlir
@@ -213,9 +219,9 @@ def ForOp : SCF_Op<"for",
     }
     ```
   }];
-  let arguments = (ins Index:$lowerBound,
-                       Index:$upperBound,
-                       Index:$step,
+  let arguments = (ins AnySignlessIntegerOrIndex:$lowerBound,
+                       AnySignlessIntegerOrIndex:$upperBound,
+                       AnySignlessIntegerOrIndex:$step,
                        Variadic<AnyType>:$initArgs);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
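
The updated description keeps the half-open range semantics but now allows the bounds and step to be, for example, i32. As an illustration only, the sketch below enumerates the induction-variable values such a loop visits. `forIterations` is a hypothetical helper, not part of MLIR. It assumes the bounds are compared as signed integers, and it steps in int64_t to sidestep i32 overflow, which says nothing about how MLIR lowerings treat overflow.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: lists the induction-variable values visited by
//   scf.for %iv = %lb to %ub step %step : i32 { ... }
// Assumes signed comparison of the bounds. The counter is widened to int64_t
// so the sketch itself never overflows; this is not a statement about MLIR's
// overflow semantics.
std::vector<int32_t> forIterations(int32_t lb, int32_t ub, int32_t step) {
  assert(step > 0 && "the step is required to be positive");
  std::vector<int32_t> ivs;
  // Half-open range [lb, ub): lb is visited, ub never is.
  for (int64_t iv = lb; iv < ub; iv += step)
    ivs.push_back(static_cast<int32_t>(iv));
  return ivs;
}
```

For instance, `forIterations(0, 10, 3)` visits 0, 3, 6 and 9, and `forIterations(5, 5, 1)` visits nothing because the upper bound is excluded.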

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 2abe18087107e..8852b79b7b7ee 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -297,10 +297,11 @@ void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
   result.addOperands(iterArgs);
   for (Value v : iterArgs)
     result.addTypes(v.getType());
+  Type t = lb.getType();
   Region *bodyRegion = result.addRegion();
   bodyRegion->push_back(new Block);
   Block &bodyBlock = bodyRegion->front();
-  bodyBlock.addArgument(builder.getIndexType(), result.location);
+  bodyBlock.addArgument(t, result.location);
   for (Value v : iterArgs)
     bodyBlock.addArgument(v.getType(), v.getLoc());
 
@@ -337,11 +338,9 @@ LogicalResult ForOp::verify() {
 LogicalResult ForOp::verifyRegions() {
   // Check that the body defines as single block argument for the induction
   // variable.
-  auto *body = getBody();
-  if (!body->getArgument(0).getType().isIndex())
+  if (getInductionVar().getType() != getLowerBound().getType())
     return emitOpError(
-        "expected body first argument to be an index argument for "
-        "the induction variable");
+        "expected induction variable to be same type as bounds and step");
 
   auto opNumResults = getNumResults();
   if (opNumResults == 0)
@@ -363,7 +362,7 @@ LogicalResult ForOp::verifyRegions() {
       return emitOpError() << "types mismatch between " << i
                            << "th iter region arg and defined value";
 
-    i++;
+    ++i;
   }
   return success();
 }
@@ -413,6 +412,8 @@ void ForOp::print(OpAsmPrinter &p) {
   if (!getIterOperands().empty())
     p << " -> (" << getIterOperands().getTypes() << ')';
   p << ' ';
+  if (Type t = getInductionVar().getType(); !t.isIndex())
+    p << " : " << t << ' ';
   p.printRegion(getRegion(),
                 /*printEntryBlockArgs=*/false,
                 /*printBlockTerminators=*/hasIterOperands());
@@ -421,21 +422,27 @@ void ForOp::print(OpAsmPrinter &p) {
 
 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
   auto &builder = parser.getBuilder();
-  Type indexType = builder.getIndexType();
+  Type type;
 
   OpAsmParser::Argument inductionVariable;
-  inductionVariable.type = indexType;
   OpAsmParser::UnresolvedOperand lb, ub, step;
 
   // Parse the induction variable followed by '='.
-  if (parser.parseArgument(inductionVariable) || parser.parseEqual() ||
+  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
       // Parse loop bounds.
-      parser.parseOperand(lb) ||
-      parser.resolveOperand(lb, indexType, result.operands) ||
-      parser.parseKeyword("to") || parser.parseOperand(ub) ||
-      parser.resolveOperand(ub, indexType, result.operands) ||
-      parser.parseKeyword("step") || parser.parseOperand(step) ||
-      parser.resolveOperand(step, indexType, result.operands))
+      parser.parseOperand(lb) || parser.parseKeyword("to") ||
+      parser.parseOperand(ub) || parser.parseKeyword("step") ||
+      parser.parseOperand(step))
+    return failure();
+  // Parse optional type, else assume Index.
+  if (parser.parseOptionalColon())
+    type = builder.getIndexType();
+  else if (parser.parseType(type))
+    return failure();
+  inductionVariable.type = type;
+  if (parser.resolveOperand(lb, type, result.operands) ||
+      parser.resolveOperand(ub, type, result.operands) ||
+      parser.resolveOperand(step, type, result.operands))
     return failure();
 
   // Parse the optional initial iteration arguments.
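
With this change the custom assembly gains an optional `: type` suffix after the step operand. The parser falls back to `index` when the suffix is absent, and the printer emits it only for non-index induction variables, so existing index loops round-trip unchanged. The following string-level model of just the loop header is a hypothetical sketch (`ForHeader`, `parseForHeader` and `printForHeader` are invented names). It assumes whitespace-separated tokens and ignores `iter_args`, the region, and the details of the real `OpAsmParser`/`OpAsmPrinter`.

```cpp
#include <cassert>
#include <optional>
#include <sstream>
#include <string>

// Hypothetical model of the loop header syntax
//   %iv = %lb to %ub step %step [: type]
// Tokens are assumed to be whitespace-separated.
struct ForHeader {
  std::string iv, lb, ub, step;
  std::string type = "index";  // default when no ": type" suffix is given
};

std::optional<ForHeader> parseForHeader(const std::string &text) {
  std::istringstream in(text);
  ForHeader h;
  std::string eq, to, stepKw;
  if (!(in >> h.iv >> eq >> h.lb >> to >> h.ub >> stepKw >> h.step))
    return std::nullopt;
  if (eq != "=" || to != "to" || stepKw != "step")
    return std::nullopt;
  // Optional ": type" suffix; otherwise keep the index default.
  std::string colon;
  if (in >> colon) {
    if (colon != ":" || !(in >> h.type))
      return std::nullopt;
    std::string extra;
    if (in >> extra)  // reject trailing tokens
      return std::nullopt;
  }
  return h;
}

// Prints the header, leaving the type implicit for index (as ForOp::print does).
std::string printForHeader(const ForHeader &h) {
  std::string s = h.iv + " = " + h.lb + " to " + h.ub + " step " + h.step;
  if (h.type != "index")
    s += " : " + h.type;
  return s;
}
```

Note that an explicit `: index` parses but is dropped on printing, mirroring the `!t.isIndex()` check in the printer.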

diff  --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index c3c396ec808f2..8566943ef8012 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics
 
 func.func @loop_for_lb(%arg0: f32, %arg1: index) {
-  // expected-error@+1 {{operand #0 must be index}}
+  // expected-error@+1 {{operand #0 must be signless integer or index}}
   "scf.for"(%arg0, %arg1, %arg1) ({}) : (f32, index, index) -> ()
   return
 }
@@ -9,7 +9,7 @@ func.func @loop_for_lb(%arg0: f32, %arg1: index) {
 // -----
 
 func.func @loop_for_ub(%arg0: f32, %arg1: index) {
-  // expected-error@+1 {{operand #1 must be index}}
+  // expected-error@+1 {{operand #1 must be signless integer or index}}
   "scf.for"(%arg1, %arg0, %arg1) ({}) : (index, f32, index) -> ()
   return
 }
@@ -17,13 +17,21 @@ func.func @loop_for_ub(%arg0: f32, %arg1: index) {
 // -----
 
 func.func @loop_for_step(%arg0: f32, %arg1: index) {
-  // expected-error@+1 {{operand #2 must be index}}
+  // expected-error@+1 {{operand #2 must be signless integer or index}}
   "scf.for"(%arg1, %arg1, %arg0) ({}) : (index, index, f32) -> ()
   return
 }
 
 // -----
 
+func.func @loop_for_mismatch(%arg0: i32, %arg1: index) {
+  // expected-error@+1 {{all of {lowerBound, upperBound, step} have same type}}
+  "scf.for"(%arg1, %arg0, %arg1) ({}) : (index, i32, index) -> ()
+  return
+}
+
+// -----
+
 func.func @loop_for_step_positive(%arg0: index) {
   // expected-error@+2 {{constant step operand must be positive}}
   %c0 = arith.constant 0 : index
@@ -63,7 +71,7 @@ func.func @loop_for_single_block(%arg0: index) {
 // -----
 
 func.func @loop_for_single_index_argument(%arg0: index) {
-  // expected-error@+1 {{op expected body first argument to be an index argument for the induction variable}}
+  // expected-error@+1 {{expected induction variable to be same type as bounds}}
   "scf.for"(%arg0, %arg0, %arg0) (
     {
     ^bb0(%i0 : f32):

diff  --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index 2314516fa1079..174ae539fe56c 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -26,6 +26,17 @@ func.func @std_for(%arg0 : index, %arg1 : index, %arg2 : index) {
 //  CHECK-NEXT:       %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : index
 //  CHECK-NEXT:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 
+func.func @std_for_i32(%arg0 : i32, %arg1 : i32, %arg2 : i32) {
+  scf.for %i0 = %arg0 to %arg1 step %arg2 : i32 {
+    scf.for %i1 = %arg0 to %arg1 step %arg2 : i32 {
+    }
+  }
+  return
+}
+// CHECK-LABEL: func @std_for_i32(
+//  CHECK-NEXT:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} : i32 {
+//  CHECK-NEXT:     scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} : i32 {
+
 func.func @std_if(%arg0: i1, %arg1: f32) {
   scf.if %arg0 {
     %0 = arith.addf %arg1, %arg1 : f32


        

