[Mlir-commits] [mlir] a4c4705 - [mlir][linalg] Fix builder API usage in `RegionBuilderHelper` (#87451)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 3 19:18:02 PDT 2024


Author: Matthias Springer
Date: 2024-04-04T11:17:59+09:00
New Revision: a4c470555b5c311770e6cb58494c573c4efe53d6

URL: https://github.com/llvm/llvm-project/commit/a4c470555b5c311770e6cb58494c573c4efe53d6
DIFF: https://github.com/llvm/llvm-project/commit/a4c470555b5c311770e6cb58494c573c4efe53d6.diff

LOG: [mlir][linalg] Fix builder API usage in `RegionBuilderHelper` (#87451)

Operations must be created with the supplied builder. Otherwise, the
dialect conversion / greedy pattern rewrite driver can break.

This commit fixes a crash in the dialect conversion:
```
within split at llvm-project/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir:1 offset :8:8: error: failed to legalize operation 'tosa.add'
  %0 = tosa.add %1, %arg2 : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
       ^
within split at llvm-project/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir:1 offset :8:8: note: see current operation: %9 = "tosa.add"(%8, %arg2) : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
mlir-opt: llvm-project/mlir/include/mlir/IR/UseDefLists.h:198: mlir::IRObjectWithUseList<mlir::OpOperand>::~IRObjectWithUseList() [OperandType = mlir::OpOperand]: Assertion `use_empty() && "Cannot destroy a value that still has uses!"' failed.
```

This commit is the proper fix for #87297 (which was reverted).

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2d7219fef87c64..9c5c58fa1fabfb 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -373,14 +373,15 @@ namespace {
 
 class RegionBuilderHelper {
 public:
-  RegionBuilderHelper(MLIRContext *context, Block &block)
-      : context(context), block(block) {}
+  RegionBuilderHelper(OpBuilder &builder, Block &block)
+      : builder(builder), block(block) {}
 
   // Build the unary functions defined by OpDSL.
   Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
     if (!isFloatingPoint(arg))
       llvm_unreachable("unsupported non numeric type");
-    OpBuilder builder = getBuilder();
+    OpBuilder::InsertionGuard g(builder);
+    builder.setInsertionPointToEnd(&block);
     switch (unaryFn) {
     case UnaryFn::exp:
       return builder.create<math::ExpOp>(arg.getLoc(), arg);
@@ -407,7 +408,8 @@ class RegionBuilderHelper {
                    arg1.getType().getIntOrFloatBitWidth() == 1;
     if (!allComplex && !allFloatingPoint && !allInteger)
       llvm_unreachable("unsupported non numeric type");
-    OpBuilder builder = getBuilder();
+    OpBuilder::InsertionGuard g(builder);
+    builder.setInsertionPointToEnd(&block);
     switch (binaryFn) {
     case BinaryFn::add:
       if (allComplex)
@@ -481,29 +483,32 @@ class RegionBuilderHelper {
   }
 
   void yieldOutputs(ValueRange values) {
-    OpBuilder builder = getBuilder();
+    OpBuilder::InsertionGuard g(builder);
+    builder.setInsertionPointToEnd(&block);
     Location loc = builder.getUnknownLoc();
     builder.create<YieldOp>(loc, values);
   }
 
   Value constant(const std::string &value) {
-    OpBuilder builder = getBuilder();
+    OpBuilder::InsertionGuard g(builder);
+    builder.setInsertionPointToEnd(&block);
     Location loc = builder.getUnknownLoc();
     Attribute valueAttr = parseAttribute(value, builder.getContext());
     return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
   }
 
   Value index(int64_t dim) {
-    OpBuilder builder = getBuilder();
+    OpBuilder::InsertionGuard g(builder);
+    builder.setInsertionPointToEnd(&block);
     return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
   }
 
   Type getIntegerType(unsigned width) {
-    return IntegerType::get(context, width);
+    return IntegerType::get(builder.getContext(), width);
   }
 
-  Type getFloat32Type() { return Float32Type::get(context); }
-  Type getFloat64Type() { return Float64Type::get(context); }
+  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
+  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
 
 private:
   // Generates operations to cast the given operand to a specified type.
@@ -511,7 +516,8 @@ class RegionBuilderHelper {
   // operand returned as-is (which will presumably yield a verification
   // issue downstream).
   Value cast(Type toType, Value operand, bool isUnsignedCast) {
-    OpBuilder builder = getBuilder();
+    OpBuilder::InsertionGuard g(builder);
+    builder.setInsertionPointToEnd(&block);
     auto loc = operand.getLoc();
     return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
   }
@@ -526,13 +532,7 @@ class RegionBuilderHelper {
     return llvm::isa<IntegerType>(value.getType());
   }
 
-  OpBuilder getBuilder() {
-    OpBuilder builder(context);
-    builder.setInsertionPointToEnd(&block);
-    return builder;
-  }
-
-  MLIRContext *context;
+  OpBuilder &builder;
   Block &block;
 };
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index d4c17928d4ca15..acea25f023980a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -274,7 +274,7 @@ struct SparseTensorCodegenPass
         });
     // The following operations and dialects may be introduced by the
     // codegen rules, and are therefore marked as legal.
-    target.addLegalOp<linalg::FillOp>();
+    target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
     target.addLegalDialect<
         arith::ArithDialect, bufferization::BufferizationDialect,
         complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 17eec593691860..ad65410e635e9c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -15,3 +15,15 @@ func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
   %0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
   return %0 : tensor<*xi8>
 }
+
+// -----
+
+// CHECK-LABEL: @unranked_add
+func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %arg2 : tensor<*xf32>) -> (tensor<10x10xf32>) {
+  // expected-error@+3 {{failed to legalize operation 'tosa.add'}}
+  %reduce = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<10x10xf32>) -> tensor<10x1xf32>
+  %1 = tosa.add %reduce, %arg1 : (tensor<10x1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+  %0 = tosa.add %1, %arg2 : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
+  %2 = tosa.reshape %0 {new_shape = array<i64: 10, 10>} : (tensor<*xf32>) -> tensor<10x10xf32>
+  return %2 : tensor<10x10xf32>
+}

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index f14e559fff92f3..fe6ad150411261 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -1008,7 +1008,7 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
                         Block &block, ArrayRef<NamedAttribute> attrs) {{
   assert({1} > 0 && block.getNumArguments() == {1} &&
          "{0} regionBuilder expects {1} (>=0) args");
-  RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
+  RegionBuilderHelper helper(b, block);
   SmallVector<Value> yields;
   {2}
   {3}


        


More information about the Mlir-commits mailing list