[Mlir-commits] [mlir] 6d42953 - [mlir] Fix scf.for with signless iterations print/parse
Jacques Pienaar
llvmlistbot at llvm.org
Wed Mar 15 23:14:42 PDT 2023
Author: Jacques Pienaar
Date: 2023-03-16T02:14:05-04:00
New Revision: 6d42953f147dcd8a95c1bc1e565cba4d613ab0b0
URL: https://github.com/llvm/llvm-project/commit/6d42953f147dcd8a95c1bc1e565cba4d613ab0b0
DIFF: https://github.com/llvm/llvm-project/commit/6d42953f147dcd8a95c1bc1e565cba4d613ab0b0.diff
LOG: [mlir] Fix scf.for with signless iterations print/parse
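The printer and the parser accidentally used different forms: the printer emits the optional induction-variable type after the `iter_args(...) -> (...)` clause, but the parser expected that type before `iter_args`. The parser now reads the optional type after the `iter_args` clause, so printed loops round-trip.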
Added:
Modified:
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/test/Dialect/SCF/ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index f3fcd5ac20263..4e7bcc499be3d 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -434,29 +434,38 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
parser.parseOperand(ub) || parser.parseKeyword("step") ||
parser.parseOperand(step))
return failure();
- // Parse optional type, else assume Index.
- if (parser.parseOptionalColon())
- type = builder.getIndexType();
- else if (parser.parseType(type))
- return failure();
- inductionVariable.type = type;
- if (parser.resolveOperand(lb, type, result.operands) ||
- parser.resolveOperand(ub, type, result.operands) ||
- parser.resolveOperand(step, type, result.operands))
- return failure();
// Parse the optional initial iteration arguments.
SmallVector<OpAsmParser::Argument, 4> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
regionArgs.push_back(inductionVariable);
- if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
+ bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+ if (hasIterArgs) {
// Parse assignment list and results type list.
if (parser.parseAssignmentList(regionArgs, operands) ||
parser.parseArrowTypeList(result.types))
return failure();
+ }
+
+ if (regionArgs.size() != result.types.size() + 1)
+ return parser.emitError(
+ parser.getNameLoc(),
+ "mismatch in number of loop-carried values and defined values");
- // Resolve input operands.
+ // Parse optional type, else assume Index.
+ if (parser.parseOptionalColon())
+ type = builder.getIndexType();
+ else if (parser.parseType(type))
+ return failure();
+
+ // Resolve input operands.
+ regionArgs.front().type = type;
+ if (parser.resolveOperand(lb, type, result.operands) ||
+ parser.resolveOperand(ub, type, result.operands) ||
+ parser.resolveOperand(step, type, result.operands))
+ return failure();
+ if (hasIterArgs) {
for (auto argOperandType :
llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
Type type = std::get<2>(argOperandType);
@@ -467,11 +476,6 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
}
}
- if (regionArgs.size() != result.types.size() + 1)
- return parser.emitError(
- parser.getNameLoc(),
- "mismatch in number of loop-carried values and defined values");
-
// Parse the body region.
Region *body = result.addRegion();
if (parser.parseRegion(*body, regionArgs))
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index 174ae539fe56c..46d175d6870ce 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -37,6 +37,18 @@ func.func @std_for_i32(%arg0 : i32, %arg1 : i32, %arg2 : i32) {
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} : i32 {
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} : i32 {
+func.func @scf_for_i64_iter(%arg1: i64, %arg2: i64) {
+ %c1_i64 = arith.constant 1 : i64
+ %c0_i64 = arith.constant 0 : i64
+ %0 = scf.for %arg3 = %arg1 to %arg2 step %c1_i64 iter_args(%arg4 = %c0_i64) -> (i64) : i64 {
+ %1 = arith.addi %arg4, %arg3 : i64
+ scf.yield %1 : i64
+ }
+ return
+}
+// CHECK-LABEL: scf_for_i64_iter
+// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}} -> (i64) : i64 {
+
func.func @std_if(%arg0: i1, %arg1: f32) {
scf.if %arg0 {
%0 = arith.addf %arg1, %arg1 : f32
More information about the Mlir-commits mailing list