[Mlir-commits] [mlir] b716bf8 - [mlir][scf] Fix builder of WhileOp with region builder arguments.

Ingo Müller llvmlistbot at llvm.org
Tue Feb 7 05:41:00 PST 2023


Author: Ingo Müller
Date: 2023-02-07T13:40:54Z
New Revision: b716bf84eaba25e0f83d1778288f65a671e85f98

URL: https://github.com/llvm/llvm-project/commit/b716bf84eaba25e0f83d1778288f65a671e85f98
DIFF: https://github.com/llvm/llvm-project/commit/b716bf84eaba25e0f83d1778288f65a671e85f98.diff

LOG: [mlir][scf] Fix builder of WhileOp with region builder arguments.

The overload of WhileOp::build with arguments for builder functions for
the regions of the op was broken: it did not correctly compute the types
(and locations) of the region arguments, which led to failed assertions
when the result types were different from the operand types.
Specifically, it used the result types (and operand locations) for *both*
regions, instead of the operand types (and locations) for the 'before'
region and the result types (and locations) for the 'after' region.

Reviewed By: Mogball, mehdi_amini

Differential Revision: https://reviews.llvm.org/D142952

Added: 
    mlir/test/Dialect/SCF/while-op-builder.mlir
    mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp

Modified: 
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/test/lib/Dialect/SCF/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 15e1a68d50ab4..6a4da00bbad36 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2764,19 +2764,24 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
 
   OpBuilder::InsertionGuard guard(odsBuilder);
 
-  SmallVector<Location, 4> blockArgLocs;
+  // Build before region.
+  SmallVector<Location, 4> beforeArgLocs;
+  beforeArgLocs.reserve(operands.size());
   for (Value operand : operands) {
-    blockArgLocs.push_back(operand.getLoc());
+    beforeArgLocs.push_back(operand.getLoc());
   }
 
   Region *beforeRegion = odsState.addRegion();
-  Block *beforeBlock = odsBuilder.createBlock(beforeRegion, /*insertPt=*/{},
-                                              resultTypes, blockArgLocs);
+  Block *beforeBlock = odsBuilder.createBlock(
+      beforeRegion, /*insertPt=*/{}, operands.getTypes(), beforeArgLocs);
   beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
 
+  // Build after region.
+  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
+
   Region *afterRegion = odsState.addRegion();
   Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
-                                             resultTypes, blockArgLocs);
+                                             resultTypes, afterArgLocs);
   afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
 }
 

diff  --git a/mlir/test/Dialect/SCF/while-op-builder.mlir b/mlir/test/Dialect/SCF/while-op-builder.mlir
new file mode 100644
index 0000000000000..e96d8d074211b
--- /dev/null
+++ b/mlir/test/Dialect/SCF/while-op-builder.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -test-scf-while-op-builder | FileCheck %s
+
+// CHECK-LABEL: @testMatchingTypes
+func.func @testMatchingTypes(%arg0 : i32) {
+  %0 = scf.while (%arg1 = %arg0) : (i32) -> (i32) {
+    %c10 = arith.constant 10 : i32
+    %1 = arith.cmpi slt, %arg1, %c10 : i32
+    scf.condition(%1) %arg1 : i32
+  } do {
+  ^bb0(%arg1: i32):
+    scf.yield %arg1 : i32
+  }
+  // Expect the same loop twice (the dummy added by the test pass and the
+  // original one).
+  // CHECK: %[[V0:.*]] = scf.while (%[[arg1:.*]] = %[[arg0:.*]]) : (i32) -> i32 {
+  // CHECK: %[[V1:.*]] = scf.while (%[[arg2:.*]] = %[[arg0]]) : (i32) -> i32 {
+  return
+}
+
+// CHECK-LABEL: @testNonMatchingTypes
+func.func @testNonMatchingTypes(%arg0 : i32) {
+  %c1 = arith.constant 1 : i32
+  %c10 = arith.constant 10 : i32
+  %0:2 = scf.while (%arg1 = %arg0) : (i32) -> (i32, i32) {
+    %1 = arith.cmpi slt, %arg1, %c10 : i32
+    scf.condition(%1) %arg1, %c1 : i32, i32
+  } do {
+  ^bb0(%arg1: i32, %arg2: i32):
+    %1 = arith.addi %arg1, %arg2 : i32
+    scf.yield %1 : i32
+  }
+  // Expect the same loop twice (the dummy added by the test pass and the
+  // original one).
+  // CHECK: %[[V0:.*]] = scf.while (%[[arg1:.*]] = %[[arg0:.*]]) : (i32) -> (i32, i32) {
+  // CHECK: %[[V1:.*]] = scf.while (%[[arg2:.*]] = %[[arg0]]) : (i32) -> (i32, i32) {
+  return
+}

diff  --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
index 36c41ab0e93bd..22c2f2388de69 100644
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRSCFTestPasses
   TestLoopParametricTiling.cpp
   TestLoopUnrolling.cpp
   TestSCFUtils.cpp
+  TestWhileOpBuilder.cpp
 
   EXCLUDE_FROM_LIBMLIR
 

diff  --git a/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp b/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp
new file mode 100644
index 0000000000000..d3113c0b0ae7d
--- /dev/null
+++ b/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp
@@ -0,0 +1,82 @@
+//===- TestWhileOpBuilder.cpp - Pass to test WhileOp::build ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test some builder functions of WhileOp. It
+// tests the regression explained in https://reviews.llvm.org/D142952, where
+// a WhileOp::build overload crashed when fed with operands of different types
+// than the result types.
+//
+// To test the build function, the pass copies each WhileOp found in the body
+// of a FuncOp and adds an additional WhileOp with the same operands and result
+// types (but dummy computations) using the builder in question.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace mlir::scf;
+
+namespace {
+struct TestSCFWhileOpBuilderPass
+    : public PassWrapper<TestSCFWhileOpBuilderPass,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFWhileOpBuilderPass)
+
+  StringRef getArgument() const final { return "test-scf-while-op-builder"; }
+  StringRef getDescription() const final {
+    return "test build functions of scf.while";
+  }
+  explicit TestSCFWhileOpBuilderPass() = default;
+  TestSCFWhileOpBuilderPass(const TestSCFWhileOpBuilderPass &pass) = default;
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+    func.walk([&](WhileOp whileOp) {
+      Location loc = whileOp->getLoc();
+      ImplicitLocOpBuilder builder(loc, whileOp);
+
+      // Create a WhileOp with the same operands and result types.
+      TypeRange resultTypes = whileOp->getResultTypes();
+      ValueRange operands = whileOp->getOperands();
+      builder.create<WhileOp>(
+          loc, resultTypes, operands, /*beforeBuilder=*/
+          [&](OpBuilder &b, Location loc, ValueRange args) {
+            // Just cast the before args into the right types for condition.
+            ImplicitLocOpBuilder builder(loc, b);
+            auto castOp =
+                builder.create<UnrealizedConversionCastOp>(resultTypes, args);
+            auto cmp = builder.create<ConstantIntOp>(/*value=*/1, /*width=*/1);
+            builder.create<ConditionOp>(cmp, castOp->getResults());
+          },
+          /*afterBuilder=*/
+          [&](OpBuilder &b, Location loc, ValueRange args) {
+            // Just cast the after args into the right types for yield.
+            ImplicitLocOpBuilder builder(loc, b);
+            auto castOp = builder.create<UnrealizedConversionCastOp>(
+                operands.getTypes(), args);
+            builder.create<YieldOp>(castOp->getResults());
+          });
+    });
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestSCFWhileOpBuilderPass() {
+  PassRegistration<TestSCFWhileOpBuilderPass>();
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 6efbeb33d2e24..33f16241b2e49 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -113,6 +113,7 @@ void registerTestPDLLPasses();
 void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
+void registerTestSCFWhileOpBuilderPass();
 void registerTestShapeMappingPass();
 void registerTestSliceAnalysisPass();
 void registerTestTensorCopyInsertionPass();
@@ -220,6 +221,7 @@ void registerTestPasses() {
   mlir::test::registerTestPDLLPasses();
   mlir::test::registerTestRecursiveTypesPass();
   mlir::test::registerTestSCFUtilsPass();
+  mlir::test::registerTestSCFWhileOpBuilderPass();
   mlir::test::registerTestShapeMappingPass();
   mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestTensorCopyInsertionPass();


        


More information about the Mlir-commits mailing list