[Mlir-commits] [mlir] 6d42953 - [mlir] Fix scf.for with signless iterations print/parse

Jacques Pienaar llvmlistbot at llvm.org
Wed Mar 15 23:14:42 PDT 2023


Author: Jacques Pienaar
Date: 2023-03-16T02:14:05-04:00
New Revision: 6d42953f147dcd8a95c1bc1e565cba4d613ab0b0

URL: https://github.com/llvm/llvm-project/commit/6d42953f147dcd8a95c1bc1e565cba4d613ab0b0
DIFF: https://github.com/llvm/llvm-project/commit/6d42953f147dcd8a95c1bc1e565cba4d613ab0b0.diff

LOG: [mlir] Fix scf.for with signless iterations print/parse

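The printer and parser for scf.for accidentally used different forms when an
explicit induction-variable type is combined with iter_args: the printed form
places the optional `: type` after the `iter_args(...) -> (...)` clause, while
the parser expected it directly after the step operand. Update the parser to
accept the printed form so that such loops round-trip.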

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/test/Dialect/SCF/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index f3fcd5ac20263..4e7bcc499be3d 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -434,29 +434,38 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
       parser.parseOperand(ub) || parser.parseKeyword("step") ||
       parser.parseOperand(step))
     return failure();
-  // Parse optional type, else assume Index.
-  if (parser.parseOptionalColon())
-    type = builder.getIndexType();
-  else if (parser.parseType(type))
-    return failure();
-  inductionVariable.type = type;
-  if (parser.resolveOperand(lb, type, result.operands) ||
-      parser.resolveOperand(ub, type, result.operands) ||
-      parser.resolveOperand(step, type, result.operands))
-    return failure();
 
   // Parse the optional initial iteration arguments.
   SmallVector<OpAsmParser::Argument, 4> regionArgs;
   SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
   regionArgs.push_back(inductionVariable);
 
-  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
+  bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+  if (hasIterArgs) {
     // Parse assignment list and results type list.
     if (parser.parseAssignmentList(regionArgs, operands) ||
         parser.parseArrowTypeList(result.types))
       return failure();
+  }
+
+  if (regionArgs.size() != result.types.size() + 1)
+    return parser.emitError(
+        parser.getNameLoc(),
+        "mismatch in number of loop-carried values and defined values");
 
-    // Resolve input operands.
+  // Parse optional type, else assume Index.
+  if (parser.parseOptionalColon())
+    type = builder.getIndexType();
+  else if (parser.parseType(type))
+    return failure();
+
+  // Resolve input operands.
+  regionArgs.front().type = type;
+  if (parser.resolveOperand(lb, type, result.operands) ||
+      parser.resolveOperand(ub, type, result.operands) ||
+      parser.resolveOperand(step, type, result.operands))
+    return failure();
+  if (hasIterArgs) {
     for (auto argOperandType :
          llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
       Type type = std::get<2>(argOperandType);
@@ -467,11 +476,6 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
     }
   }
 
-  if (regionArgs.size() != result.types.size() + 1)
-    return parser.emitError(
-        parser.getNameLoc(),
-        "mismatch in number of loop-carried values and defined values");
-
   // Parse the body region.
   Region *body = result.addRegion();
   if (parser.parseRegion(*body, regionArgs))

diff  --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index 174ae539fe56c..46d175d6870ce 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -37,6 +37,18 @@ func.func @std_for_i32(%arg0 : i32, %arg1 : i32, %arg2 : i32) {
 //  CHECK-NEXT:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} : i32 {
 //  CHECK-NEXT:     scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} : i32 {
 
+func.func @scf_for_i64_iter(%arg1: i64, %arg2: i64) {
+  %c1_i64 = arith.constant 1 : i64
+  %c0_i64 = arith.constant 0 : i64
+  %0 = scf.for %arg3 = %arg1 to %arg2 step %c1_i64 iter_args(%arg4 = %c0_i64) -> (i64) : i64 {
+    %1 = arith.addi %arg4, %arg3 : i64
+    scf.yield %1 : i64
+  }
+  return
+}
+// CHECK-LABEL: scf_for_i64_iter
+// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}} -> (i64) : i64 {
+
 func.func @std_if(%arg0: i1, %arg1: f32) {
   scf.if %arg0 {
     %0 = arith.addf %arg1, %arg1 : f32


        


More information about the Mlir-commits mailing list