[Mlir-commits] [mlir] [MLIR][Linalg] Harden parsing Linalg named ops (PR #145337)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 23 07:59:12 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
This threads proper error handling / reporting capabilities through the Linalg region builders, so that parsing an invalid linalg op emits a diagnostic instead of hitting llvm_unreachable.
---
Patch is 33.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145337.diff
10 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/Linalg.h (+4)
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td (+2-1)
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (+1-1)
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+33-17)
- (modified) mlir/lib/CAPI/Dialect/Linalg.cpp (+1-1)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+134-35)
- (modified) mlir/test/Dialect/Linalg/invalid.mlir (+10)
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+8-4)
- (modified) mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml (+16-9)
- (modified) mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp (+12-5)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index 57bf6305a469d..a0fb0111d6ace 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -16,6 +16,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
@@ -26,6 +27,9 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
+
+#include "llvm/ADT/STLFunctionalExtras.h"
+
#include <optional>
namespace mlir {
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 33601c5d6dad9..a459656b982e6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -52,7 +52,8 @@ def Linalg_Dialect : Dialect {
kMemoizedIndexingMapsAttrName = "linalg.memoized_indexing_maps";
using RegionBuilderFunType = llvm::function_ref<
- void(ImplicitLocOpBuilder &b, Block &, ArrayRef<NamedAttribute>)>;
+ void(ImplicitLocOpBuilder &b, Block &, ArrayRef<NamedAttribute>,
+ function_ref<InFlightDiagnostic()>)>;
RegionBuilderFunType getRegionBuilder(StringRef name) {
return namedStructuredOpRegionBuilders.lookup(name);
}
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 74c4c0a8835f2..594d6c757d7bd 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -842,7 +842,7 @@ def LinalgStructuredInterface
Returns a null function if this named op does not define a region
builder.
}],
- /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>)>",
+ /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>, function_ref<InFlightDiagnostic()>)>",
/*methodName=*/"getRegionBuilder",
(ins),
[{ return ConcreteOp::getRegionBuilder(); }]
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 61783812920bc..7bbc56f549c0b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -192,7 +192,8 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
}
static std::function<void(ImplicitLocOpBuilder &,
- Block &, ArrayRef<NamedAttribute>)>
+ Block &, ArrayRef<NamedAttribute>,
+ function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return nullptr;
}
@@ -300,7 +301,8 @@ def MapOp : LinalgStructuredBase_Op<"map", [
}
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
- mlir::ArrayRef<mlir::NamedAttribute>)>
+ mlir::ArrayRef<mlir::NamedAttribute>,
+ function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return nullptr;
}
@@ -380,7 +382,8 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
// Implement functions necessary for DestinationStyleOpInterface.
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
- mlir::ArrayRef<mlir::NamedAttribute>)>
+ mlir::ArrayRef<mlir::NamedAttribute>,
+ function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return nullptr;
}
@@ -449,13 +452,14 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
- mlir::ArrayRef<mlir::NamedAttribute>) {
+ mlir::ArrayRef<mlir::NamedAttribute>, function_ref<InFlightDiagnostic()> emitError) {
OpBuilder::InsertionGuard guard(b);
b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
}
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
- mlir::ArrayRef<mlir::NamedAttribute>)>
+ mlir::ArrayRef<mlir::NamedAttribute>,
+ function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return regionBuilder;
}
@@ -521,13 +525,15 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
- mlir::ArrayRef<mlir::NamedAttribute>) {
+ mlir::ArrayRef<mlir::NamedAttribute>,
+ function_ref<InFlightDiagnostic()> emitError) {
OpBuilder::InsertionGuard guard(b);
b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
}
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
- mlir::ArrayRef<mlir::NamedAttribute>)>
+ mlir::ArrayRef<mlir::NamedAttribute>,
+ function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return regionBuilder;
}
@@ -631,10 +637,12 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
/// Implements the block region builder for the elementwiseOp. This is
/// called by the 'fillStructuredOpRegion'.
static void regionBuilder(ImplicitLocOpBuilder &b,
- Block &block, ArrayRef<NamedAttribute> attrs);
+ Block &block, ArrayRef<NamedAttribute> attrs,
+ function_ref<InFlightDiagnostic()> emitError);
static std::function<void(ImplicitLocOpBuilder &,
- Block &, ArrayRef<NamedAttribute>)>
+ Block &, ArrayRef<NamedAttribute>,
+ function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return regionBuilder;
}
@@ -771,7 +779,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
/// Implements the block region builder.
static void regionBuilder(ImplicitLocOpBuilder &b,
- Block &block, ArrayRef<NamedAttribute> attrs);
+ Block &block, ArrayRef<NamedAttribute> attrs,
+ function_ref<InFlightDiagnostic()> emitError);
/// Returns a list of AffineMap with the default matmul indexing charactristic.
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
@@ -780,7 +789,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
static std::function<void(ImplicitLocOpBuilder &,
- Block &, ArrayRef<NamedAttribute>)>
+ Block &, ArrayRef<NamedAttribute>,
+ function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return regionBuilder;
}
@@ -916,10 +926,12 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
static unsigned getNumRegionArgs();
static void regionBuilder(ImplicitLocOpBuilder &b,
- Block &block, ArrayRef<NamedAttribute> attrs);
+ Block &block, ArrayRef<NamedAttribute> attrs,
+ function_ref<InFlightDiagnostic()> emitError);
static std::function<void(ImplicitLocOpBuilder &,
- Block &, ArrayRef<NamedAttribute>)>
+ Block &, ArrayRef<NamedAttribute>,
+ function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return regionBuilder;
}
@@ -1033,9 +1045,11 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
SmallVector<utils::IteratorType> getIteratorTypesArray();
static void regionBuilder(ImplicitLocOpBuilder &b,
- Block &block, ArrayRef<NamedAttribute> attrs);
+ Block &block, ArrayRef<NamedAttribute> attrs,
+ function_ref<InFlightDiagnostic()> emitError);
static std::function<void(ImplicitLocOpBuilder &,
- Block &, ArrayRef<NamedAttribute>)>
+ Block &, ArrayRef<NamedAttribute>,
+ function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return regionBuilder;
}
@@ -1161,7 +1175,8 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
/// Implements the block region builder.
static void regionBuilder(ImplicitLocOpBuilder &b,
- Block &block, ArrayRef<NamedAttribute> attrs);
+ Block &block, ArrayRef<NamedAttribute> attrs,
+ function_ref<InFlightDiagnostic()> emitError);
/// Returns a list of AffineMap with the default batch_reduce_matmul indexing charactristic.
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
@@ -1170,7 +1185,8 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
static std::function<void(ImplicitLocOpBuilder &,
- Block &, ArrayRef<NamedAttribute>)>
+ Block &, ArrayRef<NamedAttribute>,
+ function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return regionBuilder;
}
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 0c4f6e88e7078..21db18dfd47ed 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -38,7 +38,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
Region &region = op->getRegion(0);
Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
b.setInsertionPointToStart(body);
- fun(b, *body, op->getAttrs());
+ fun(b, *body, op->getAttrs(), /*emitError=*/{});
}
MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5dbb2403eddbd..9cc60394e6635 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -117,8 +117,9 @@ OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//
-using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
- ArrayRef<NamedAttribute>)>;
+using RegionBuilderFn = llvm::function_ref<void(
+ ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>,
+ function_ref<InFlightDiagnostic()>)>;
/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
@@ -128,6 +129,7 @@ using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
TypeRange inputTypes, TypeRange outputTypes,
ArrayRef<NamedAttribute> attrs,
+ function_ref<InFlightDiagnostic()> emitError,
RegionBuilderFn regionBuilder) {
SmallVector<Type, 8> argTypes;
SmallVector<Location, 8> argLocs;
@@ -148,7 +150,7 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
opBuilder.setInsertionPointToStart(body);
ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
- regionBuilder(b, *body, attrs);
+ regionBuilder(b, *body, attrs, emitError);
// indexing_maps is an auto-generated method.
@@ -184,7 +186,8 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
// Create and fill the region of the structured operation.
Region &region = *state.addRegion();
fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
- state.attributes.getAttrs(), regionBuilder);
+ state.attributes.getAttrs(), /*emitError=*/{},
+ regionBuilder);
}
static void buildMatmulOp(OpBuilder &b, OperationState &state,
@@ -339,9 +342,15 @@ static ParseResult parseNamedStructuredOpRegion(
}
OpBuilder opBuilder(parser.getContext());
- fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
- regionBuilder);
- return success();
+ ParseResult result = success();
+ fillStructuredOpRegion(
+ opBuilder, region, inputTypes, outputTypes, attrs,
+ [&]() {
+ result = failure();
+ return parser.emitError(parser.getCurrentLocation());
+ },
+ regionBuilder);
+ return result;
}
static ParseResult
@@ -435,9 +444,15 @@ class RegionBuilderHelper {
: builder(builder), block(block) {}
// Build the unary functions defined by OpDSL.
- Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
- if (!isFloatingPoint(arg))
+ Value buildUnaryFn(UnaryFn unaryFn, Value arg,
+ function_ref<InFlightDiagnostic()> emitError = {}) {
+ if (!isFloatingPoint(arg)) {
+ if (emitError) {
+ emitError() << "unsupported non numeric type";
+ return nullptr;
+ }
llvm_unreachable("unsupported non numeric type");
+ }
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToEnd(&block);
switch (unaryFn) {
@@ -472,18 +487,34 @@ class RegionBuilderHelper {
case UnaryFn::erf:
return builder.create<math::ErfOp>(arg.getLoc(), arg);
}
+ if (emitError) {
+ emitError() << "unsupported unary function";
+ return nullptr;
+ }
llvm_unreachable("unsupported unary function");
}
// Build the binary functions defined by OpDSL.
- Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
+ // If emitError is provided, an error will be emitted if the operation is not
+ // supported and a nullptr will be returned, otherwise an assertion will be
+ // raised.
+ Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
+ function_ref<InFlightDiagnostic()> emitError = {}) {
bool allComplex = isComplex(arg0) && isComplex(arg1);
bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
bool allInteger = isInteger(arg0) && isInteger(arg1);
bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
arg1.getType().getIntOrFloatBitWidth() == 1;
- if (!allComplex && !allFloatingPoint && !allInteger)
+ if (!allComplex && !allFloatingPoint && !allInteger) {
+ if (emitError) {
+ emitError()
+ << "Cannot build binary Linalg operation: expects allComplex, "
+ "allFloatingPoint, or allInteger, got "
+ << arg0.getType() << " and " << arg1.getType();
+ return nullptr;
+ }
llvm_unreachable("unsupported non numeric type");
+ }
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToEnd(&block);
switch (binaryFn) {
@@ -500,8 +531,13 @@ class RegionBuilderHelper {
return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
- if (allBool)
+ if (allBool) {
+ if (emitError) {
+ emitError() << "unsupported operation: sub with bools";
+ return nullptr;
+ }
llvm_unreachable("unsupported operation: sub with bools");
+ }
return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::mul:
if (allComplex)
@@ -516,12 +552,22 @@ class RegionBuilderHelper {
return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
- if (allBool)
+ if (allBool) {
+ if (emitError) {
+ emitError() << "unsupported operation: div with bools";
+ return nullptr;
+ }
llvm_unreachable("unsupported operation: div with bools");
+ }
return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::div_unsigned:
- if (!allInteger || allBool)
+ if (!allInteger || allBool) {
+ if (emitError) {
+ emitError() << "unsupported operation: unsigned div not on uint";
+ return nullptr;
+ }
llvm_unreachable("unsupported operation: unsigned div not on uint");
+ }
return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::max_signed:
assert(!allComplex);
@@ -547,12 +593,16 @@ class RegionBuilderHelper {
assert(allFloatingPoint);
return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
}
+ if (emitError) {
+ emitError() << "unsupported binary function";
+ return nullptr;
+ }
llvm_unreachable("unsupported binary function");
}
// Build the ternary functions defined by OpDSL.
- Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
- Value arg2) {
+ Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
+ function_ref<InFlightDiagnostic()> emitError = {}) {
bool headBool =
isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
bool tailFloatingPoint =
@@ -566,17 +616,26 @@ class RegionBuilderHelper {
llvm_unreachable("unsupported non numeric type");
return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
}
+ if (emitError) {
+ emitError() << "unsupported ternary function";
+ return nullptr;
+ }
llvm_unreachable("unsupported ternary function");
}
// Build the type functions defined by OpDSL.
- Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
+ Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
+ function_ref<InFlightDiagnostic()> emitError = {}) {
switch (typeFn) {
case TypeFn::cast_signed:
return cast(toType, operand, false);
case TypeFn::cast_unsigned:
return cast(toType, operand, true);
}
+ if (emitError) {
+ emitError() << "unsupported type conversion function";
+ return nullptr;
+ }
llvm_unreachable("unsupported type conversion function");
}
@@ -3664,9 +3723,15 @@ bool MatmulOp::hasUserDefinedMaps() {
/// Implements the block region builder for the MatmulOp. This is called by
/// 'fillStructuredOpRegion'.
void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
- ArrayRef<NamedAttribute> attrs) {
- assert(3 > 0 && block.getNumArguments() == 3 &&
- "MatmulOp regionBuilder expects 3 (>=0) args");
+ ArrayRef<NamedAttribute> attrs,
+ function_ref<InFlightDiagnostic()> emitError) {
+ if (emitError && block.getNumArguments() != 3) {
+ emitError() << "MatmulOp regionBuilder expects 3 args, got "
+ << block.getNumArguments();
+ return;
+ }
+ assert(block.getNumArguments() == 3 &&
+ "MatmulOp regionBuilder expects 3 args");
RegionBuilderHelper helper(b, block);
SmallVector<Value> yields;
@@ -3683,9 +3748,13 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
block.getArgument(0));
Value value2 = helper.buildTypeFn(cast...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/145337
More information about the Mlir-commits
mailing list