[Mlir-commits] [mlir] a4c4705 - [mlir][linalg] Fix builder API usage in `RegionBuilderHelper` (#87451)
llvmlistbot at llvm.org
Wed Apr 3 19:18:02 PDT 2024
Author: Matthias Springer
Date: 2024-04-04T11:17:59+09:00
New Revision: a4c470555b5c311770e6cb58494c573c4efe53d6
URL: https://github.com/llvm/llvm-project/commit/a4c470555b5c311770e6cb58494c573c4efe53d6
DIFF: https://github.com/llvm/llvm-project/commit/a4c470555b5c311770e6cb58494c573c4efe53d6.diff
LOG: [mlir][linalg] Fix builder API usage in `RegionBuilderHelper` (#87451)
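Operations must be created with the supplied builder. Otherwise, they are
created behind the back of the rewriter that drives the dialect conversion /
greedy pattern rewrite (the rewriter is never notified of them), and the
driver can break.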
This commit fixes a crash in the dialect conversion:
```
within split at llvm-project/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir:1 offset :8:8: error: failed to legalize operation 'tosa.add'
%0 = tosa.add %1, %arg2 : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
^
within split at llvm-project/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir:1 offset :8:8: note: see current operation: %9 = "tosa.add"(%8, %arg2) : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
mlir-opt: llvm-project/mlir/include/mlir/IR/UseDefLists.h:198: mlir::IRObjectWithUseList<mlir::OpOperand>::~IRObjectWithUseList() [OperandType = mlir::OpOperand]: Assertion `use_empty() && "Cannot destroy a value that still has uses!"' failed.
```
This commit is the proper fix for #87297 (which was reverted).
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2d7219fef87c64..9c5c58fa1fabfb 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -373,14 +373,15 @@ namespace {
class RegionBuilderHelper {
public:
- RegionBuilderHelper(MLIRContext *context, Block &block)
- : context(context), block(block) {}
+ RegionBuilderHelper(OpBuilder &builder, Block &block)
+ : builder(builder), block(block) {}
// Build the unary functions defined by OpDSL.
Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
if (!isFloatingPoint(arg))
llvm_unreachable("unsupported non numeric type");
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
switch (unaryFn) {
case UnaryFn::exp:
return builder.create<math::ExpOp>(arg.getLoc(), arg);
@@ -407,7 +408,8 @@ class RegionBuilderHelper {
arg1.getType().getIntOrFloatBitWidth() == 1;
if (!allComplex && !allFloatingPoint && !allInteger)
llvm_unreachable("unsupported non numeric type");
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
switch (binaryFn) {
case BinaryFn::add:
if (allComplex)
@@ -481,29 +483,32 @@ class RegionBuilderHelper {
}
void yieldOutputs(ValueRange values) {
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
Location loc = builder.getUnknownLoc();
builder.create<YieldOp>(loc, values);
}
Value constant(const std::string &value) {
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
Location loc = builder.getUnknownLoc();
Attribute valueAttr = parseAttribute(value, builder.getContext());
return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
}
Value index(int64_t dim) {
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
}
Type getIntegerType(unsigned width) {
- return IntegerType::get(context, width);
+ return IntegerType::get(builder.getContext(), width);
}
- Type getFloat32Type() { return Float32Type::get(context); }
- Type getFloat64Type() { return Float64Type::get(context); }
+ Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
+ Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
private:
// Generates operations to cast the given operand to a specified type.
@@ -511,7 +516,8 @@ class RegionBuilderHelper {
// operand returned as-is (which will presumably yield a verification
// issue downstream).
Value cast(Type toType, Value operand, bool isUnsignedCast) {
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
auto loc = operand.getLoc();
return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
}
@@ -526,13 +532,7 @@ class RegionBuilderHelper {
return llvm::isa<IntegerType>(value.getType());
}
- OpBuilder getBuilder() {
- OpBuilder builder(context);
- builder.setInsertionPointToEnd(&block);
- return builder;
- }
-
- MLIRContext *context;
+ OpBuilder &builder;
Block █
};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index d4c17928d4ca15..acea25f023980a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -274,7 +274,7 @@ struct SparseTensorCodegenPass
});
// The following operations and dialects may be introduced by the
// codegen rules, and are therefore marked as legal.
- target.addLegalOp<linalg::FillOp>();
+ target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
target.addLegalDialect<
arith::ArithDialect, bufferization::BufferizationDialect,
complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 17eec593691860..ad65410e635e9c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -15,3 +15,15 @@ func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
%0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
return %0 : tensor<*xi8>
}
+
+// -----
+
+// CHECK-LABEL: @unranked_add
+func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %arg2 : tensor<*xf32>) -> (tensor<10x10xf32>) {
+ // expected-error at +3 {{failed to legalize operation 'tosa.add'}}
+ %reduce = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<10x10xf32>) -> tensor<10x1xf32>
+ %1 = tosa.add %reduce, %arg1 : (tensor<10x1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+ %0 = tosa.add %1, %arg2 : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
+ %2 = tosa.reshape %0 {new_shape = array<i64: 10, 10>} : (tensor<*xf32>) -> tensor<10x10xf32>
+ return %2 : tensor<10x10xf32>
+}
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index f14e559fff92f3..fe6ad150411261 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -1008,7 +1008,7 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs) {{
assert({1} > 0 && block.getNumArguments() == {1} &&
"{0} regionBuilder expects {1} (>=0) args");
- RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
+ RegionBuilderHelper helper(b, block);
SmallVector<Value> yields;
{2}
{3}
More information about the Mlir-commits mailing list