[Mlir-commits] [mlir] [MLIR][Linalg] Harden parsing Linalg named ops (PR #145337)

Mehdi Amini llvmlistbot at llvm.org
Tue Jun 24 07:41:12 PDT 2025


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/145337

>From 8feef28799c7bd4f35a8cc4b028a5e4029e02b9a Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Mon, 23 Jun 2025 07:25:48 -0700
Subject: [PATCH] [MLIR][Linalg] Harden parsing Linalg named ops

This threads proper error handling / reporting capabilities through the
region builders to avoid hitting llvm_unreachable while parsing linalg ops.
---
 mlir/include/mlir/Dialect/Linalg/IR/Linalg.h  |   4 +
 .../mlir/Dialect/Linalg/IR/LinalgBase.td      |   3 +-
 .../Dialect/Linalg/IR/LinalgInterfaces.td     |   2 +-
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  50 +++--
 mlir/lib/CAPI/Dialect/Linalg.cpp              |   2 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 181 ++++++++++++++----
 mlir/test/Dialect/Linalg/invalid.mlir         |  40 ++++
 mlir/test/lib/Dialect/Test/TestOps.td         |  12 +-
 .../test-linalg-ods-yaml-gen.yaml             |  25 ++-
 .../mlir-linalg-ods-yaml-gen.cpp              |  17 +-
 10 files changed, 261 insertions(+), 75 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index 57bf6305a469d..a0fb0111d6ace 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -16,6 +16,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -26,6 +27,9 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
+
+#include "llvm/ADT/STLFunctionalExtras.h"
+
 #include <optional>
 
 namespace mlir {
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 33601c5d6dad9..a459656b982e6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -52,7 +52,8 @@ def Linalg_Dialect : Dialect {
         kMemoizedIndexingMapsAttrName = "linalg.memoized_indexing_maps";
 
     using RegionBuilderFunType = llvm::function_ref<
-      void(ImplicitLocOpBuilder &b, Block &, ArrayRef<NamedAttribute>)>;
+      void(ImplicitLocOpBuilder &b, Block &, ArrayRef<NamedAttribute>,
+           function_ref<InFlightDiagnostic()>)>;
     RegionBuilderFunType getRegionBuilder(StringRef name) {
       return namedStructuredOpRegionBuilders.lookup(name);
     }
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index ca1cba8747bd8..ba73cfbbed845 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -720,7 +720,7 @@ def LinalgStructuredInterface
         Returns a null function if this named op does not define a region
         builder.
       }],
-      /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>)>",
+      /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>, function_ref<InFlightDiagnostic()>)>",
       /*methodName=*/"getRegionBuilder",
       (ins),
       [{ return ConcreteOp::getRegionBuilder(); }]
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 61783812920bc..7bbc56f549c0b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -192,7 +192,8 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
     }
 
     static std::function<void(ImplicitLocOpBuilder &,
-                              Block &, ArrayRef<NamedAttribute>)>
+                              Block &, ArrayRef<NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
     getRegionBuilder() {
       return nullptr;
     }
@@ -300,7 +301,8 @@ def MapOp : LinalgStructuredBase_Op<"map", [
     }
 
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
-                              mlir::ArrayRef<mlir::NamedAttribute>)>
+                              mlir::ArrayRef<mlir::NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
     getRegionBuilder() {
       return nullptr;
     }
@@ -380,7 +382,8 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
 
     // Implement functions necessary for DestinationStyleOpInterface.
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
-                              mlir::ArrayRef<mlir::NamedAttribute>)>
+                              mlir::ArrayRef<mlir::NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
     getRegionBuilder() {
       return nullptr;
     }
@@ -449,13 +452,14 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
     MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
 
     static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
-        mlir::ArrayRef<mlir::NamedAttribute>) {
+        mlir::ArrayRef<mlir::NamedAttribute>, function_ref<InFlightDiagnostic()> emitError) {
       OpBuilder::InsertionGuard guard(b);
       b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
     }
 
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
-        mlir::ArrayRef<mlir::NamedAttribute>)>
+                              mlir::ArrayRef<mlir::NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
       getRegionBuilder() {
       return regionBuilder;
     }
@@ -521,13 +525,15 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
     MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
 
     static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
-        mlir::ArrayRef<mlir::NamedAttribute>) {
+                              mlir::ArrayRef<mlir::NamedAttribute>, 
+                              function_ref<InFlightDiagnostic()> emitError) {
       OpBuilder::InsertionGuard guard(b);
       b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
     }
 
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
-        mlir::ArrayRef<mlir::NamedAttribute>)>
+                              mlir::ArrayRef<mlir::NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
       getRegionBuilder() {
       return regionBuilder;
     }
@@ -631,10 +637,12 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
       /// Implements the block region builder for the elementwiseOp. This is
       /// called by the 'fillStructuredOpRegion'.
       static void regionBuilder(ImplicitLocOpBuilder &b,
-                                Block &block, ArrayRef<NamedAttribute> attrs);
+                                Block &block, ArrayRef<NamedAttribute> attrs,
+                                function_ref<InFlightDiagnostic()> emitError);
 
       static std::function<void(ImplicitLocOpBuilder &,
-                                Block &, ArrayRef<NamedAttribute>)>
+                                Block &, ArrayRef<NamedAttribute>,
+                                function_ref<InFlightDiagnostic()>)>
       getRegionBuilder() {
         return regionBuilder;
       }
@@ -771,7 +779,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
 
       /// Implements the block region builder.
       static void regionBuilder(ImplicitLocOpBuilder &b,
-                                Block &block, ArrayRef<NamedAttribute> attrs);
+                                Block &block, ArrayRef<NamedAttribute> attrs,
+                                function_ref<InFlightDiagnostic()> emitError);
 
       /// Returns a list of AffineMap with the default matmul indexing charactristic.
       static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
@@ -780,7 +789,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
       bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
 
       static std::function<void(ImplicitLocOpBuilder &,
-                                Block &, ArrayRef<NamedAttribute>)>
+                                Block &, ArrayRef<NamedAttribute>,
+                                function_ref<InFlightDiagnostic()>)>
       getRegionBuilder() {
         return regionBuilder;
       }
@@ -916,10 +926,12 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
     static unsigned getNumRegionArgs();
 
     static void regionBuilder(ImplicitLocOpBuilder &b,
-                              Block &block, ArrayRef<NamedAttribute> attrs);
+                              Block &block, ArrayRef<NamedAttribute> attrs,
+                              function_ref<InFlightDiagnostic()> emitError);
 
     static std::function<void(ImplicitLocOpBuilder &,
-                              Block &, ArrayRef<NamedAttribute>)>
+                              Block &, ArrayRef<NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
     getRegionBuilder() {
       return regionBuilder;
     }
@@ -1033,9 +1045,11 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
 
       SmallVector<utils::IteratorType> getIteratorTypesArray();
       static void regionBuilder(ImplicitLocOpBuilder &b,
-                                Block &block, ArrayRef<NamedAttribute> attrs);
+                                Block &block, ArrayRef<NamedAttribute> attrs,
+                                function_ref<InFlightDiagnostic()> emitError);
       static std::function<void(ImplicitLocOpBuilder &,
-                                Block &, ArrayRef<NamedAttribute>)>
+                                Block &, ArrayRef<NamedAttribute>,
+                                function_ref<InFlightDiagnostic()>)>
       getRegionBuilder() {
         return regionBuilder;
       }
@@ -1161,7 +1175,8 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
 
       /// Implements the block region builder.
       static void regionBuilder(ImplicitLocOpBuilder &b,
-                                Block &block, ArrayRef<NamedAttribute> attrs);
+                                Block &block, ArrayRef<NamedAttribute> attrs,
+                                function_ref<InFlightDiagnostic()> emitError);
 
       /// Returns a list of AffineMap with the default batch_reduce_matmul indexing charactristic.
       static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
@@ -1170,7 +1185,8 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
       bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
 
       static std::function<void(ImplicitLocOpBuilder &,
-                                Block &, ArrayRef<NamedAttribute>)>
+                                Block &, ArrayRef<NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
       getRegionBuilder() {
         return regionBuilder;
       }
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 0c4f6e88e7078..21db18dfd47ed 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -38,7 +38,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
   Region &region = op->getRegion(0);
   Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
   b.setInsertionPointToStart(body);
-  fun(b, *body, op->getAttrs());
+  fun(b, *body, op->getAttrs(), /*emitError=*/{});
 }
 
 MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5dbb2403eddbd..5ab44607d6c4a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -117,8 +117,9 @@ OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
 // Support for named Linalg ops defined in ods-gen.
 //===----------------------------------------------------------------------===//
 
-using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
-                                                ArrayRef<NamedAttribute>)>;
+using RegionBuilderFn = llvm::function_ref<void(
+    ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>,
+    function_ref<InFlightDiagnostic()>)>;
 
 /// Fills the region of a structured operation using the provided
 /// `regionBuilder`. The method is used by both named structured ops created by
@@ -128,6 +129,7 @@ using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
 static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                    TypeRange inputTypes, TypeRange outputTypes,
                                    ArrayRef<NamedAttribute> attrs,
+                                   function_ref<InFlightDiagnostic()> emitError,
                                    RegionBuilderFn regionBuilder) {
   SmallVector<Type, 8> argTypes;
   SmallVector<Location, 8> argLocs;
@@ -148,7 +150,7 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
 
   opBuilder.setInsertionPointToStart(body);
   ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
-  regionBuilder(b, *body, attrs);
+  regionBuilder(b, *body, attrs, emitError);
 
   // indexing_maps is an auto-generated method.
 
@@ -184,7 +186,8 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
   // Create and fill the region of the structured operation.
   Region &region = *state.addRegion();
   fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
-                         state.attributes.getAttrs(), regionBuilder);
+                         state.attributes.getAttrs(), /*emitError=*/{},
+                         regionBuilder);
 }
 
 static void buildMatmulOp(OpBuilder &b, OperationState &state,
@@ -329,7 +332,7 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
 static ParseResult parseNamedStructuredOpRegion(
     OpAsmParser &parser, Region &region, unsigned numRegionArgs,
     TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
-    RegionBuilderFn regionBuilder) {
+    RegionBuilderFn regionBuilder, SMLoc loc) {
   if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
     return parser.emitError(
         parser.getCurrentLocation(),
@@ -339,9 +342,15 @@ static ParseResult parseNamedStructuredOpRegion(
   }
 
   OpBuilder opBuilder(parser.getContext());
-  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
-                         regionBuilder);
-  return success();
+  ParseResult result = success();
+  fillStructuredOpRegion(
+      opBuilder, region, inputTypes, outputTypes, attrs,
+      [&]() {
+        result = failure();
+        return parser.emitError(loc);
+      },
+      regionBuilder);
+  return result;
 }
 
 static ParseResult
@@ -358,6 +367,7 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                           RegionBuilderFn regionBuilder) {
   // TODO: Enable when ods-gen supports captures.
   SmallVector<Type, 1> inputTypes, outputTypes;
+  SMLoc loc = parser.getCurrentLocation();
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
     return failure();
 
@@ -375,7 +385,7 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
   std::unique_ptr<Region> region = std::make_unique<Region>();
   if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                    outputTypes, result.attributes.getAttrs(),
-                                   regionBuilder))
+                                   regionBuilder, loc))
     return failure();
   result.addRegion(std::move(region));
 
@@ -435,9 +445,15 @@ class RegionBuilderHelper {
       : builder(builder), block(block) {}
 
   // Build the unary functions defined by OpDSL.
-  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
-    if (!isFloatingPoint(arg))
+  Value buildUnaryFn(UnaryFn unaryFn, Value arg,
+                     function_ref<InFlightDiagnostic()> emitError = {}) {
+    if (!isFloatingPoint(arg)) {
+      if (emitError) {
+        emitError() << "unsupported non numeric type";
+        return nullptr;
+      }
       llvm_unreachable("unsupported non numeric type");
+    }
     OpBuilder::InsertionGuard g(builder);
     builder.setInsertionPointToEnd(&block);
     switch (unaryFn) {
@@ -472,18 +488,34 @@ class RegionBuilderHelper {
     case UnaryFn::erf:
       return builder.create<math::ErfOp>(arg.getLoc(), arg);
     }
+    if (emitError) {
+      emitError() << "unsupported unary function";
+      return nullptr;
+    }
     llvm_unreachable("unsupported unary function");
   }
 
   // Build the binary functions defined by OpDSL.
-  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
+  // If emitError is provided, an error will be emitted if the operation is not
+  // supported and a nullptr will be returned, otherwise an assertion will be
+  // raised.
+  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
+                      function_ref<InFlightDiagnostic()> emitError = {}) {
     bool allComplex = isComplex(arg0) && isComplex(arg1);
     bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
     bool allInteger = isInteger(arg0) && isInteger(arg1);
     bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                    arg1.getType().getIntOrFloatBitWidth() == 1;
-    if (!allComplex && !allFloatingPoint && !allInteger)
+    if (!allComplex && !allFloatingPoint && !allInteger) {
+      if (emitError) {
+        emitError()
+            << "Cannot build binary Linalg operation: expects allComplex, "
+               "allFloatingPoint, or allInteger, got "
+            << arg0.getType() << " and " << arg1.getType();
+        return nullptr;
+      }
       llvm_unreachable("unsupported non numeric type");
+    }
     OpBuilder::InsertionGuard g(builder);
     builder.setInsertionPointToEnd(&block);
     switch (binaryFn) {
@@ -500,8 +532,13 @@ class RegionBuilderHelper {
         return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
       if (allFloatingPoint)
         return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
-      if (allBool)
+      if (allBool) {
+        if (emitError) {
+          emitError() << "unsupported operation: sub with bools";
+          return nullptr;
+        }
         llvm_unreachable("unsupported operation: sub with bools");
+      }
       return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::mul:
       if (allComplex)
@@ -516,12 +553,22 @@ class RegionBuilderHelper {
         return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
       if (allFloatingPoint)
         return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
-      if (allBool)
+      if (allBool) {
+        if (emitError) {
+          emitError() << "unsupported operation: div with bools";
+          return nullptr;
+        }
         llvm_unreachable("unsupported operation: div with bools");
+      }
       return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::div_unsigned:
-      if (!allInteger || allBool)
+      if (!allInteger || allBool) {
+        if (emitError) {
+          emitError() << "unsupported operation: unsigned div not on uint";
+          return nullptr;
+        }
         llvm_unreachable("unsupported operation: unsigned div not on uint");
+      }
       return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::max_signed:
       assert(!allComplex);
@@ -547,12 +594,16 @@ class RegionBuilderHelper {
       assert(allFloatingPoint);
       return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
     }
+    if (emitError) {
+      emitError() << "unsupported binary function";
+      return nullptr;
+    }
     llvm_unreachable("unsupported binary function");
   }
 
   // Build the ternary functions defined by OpDSL.
-  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
-                       Value arg2) {
+  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
+                       function_ref<InFlightDiagnostic()> emitError = {}) {
     bool headBool =
         isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
     bool tailFloatingPoint =
@@ -566,17 +617,26 @@ class RegionBuilderHelper {
         llvm_unreachable("unsupported non numeric type");
       return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
     }
+    if (emitError) {
+      emitError() << "unsupported ternary function";
+      return nullptr;
+    }
     llvm_unreachable("unsupported ternary function");
   }
 
   // Build the type functions defined by OpDSL.
-  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
+  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
+                    function_ref<InFlightDiagnostic()> emitError = {}) {
     switch (typeFn) {
     case TypeFn::cast_signed:
       return cast(toType, operand, false);
     case TypeFn::cast_unsigned:
       return cast(toType, operand, true);
     }
+    if (emitError) {
+      emitError() << "unsupported type conversion function";
+      return nullptr;
+    }
     llvm_unreachable("unsupported type conversion function");
   }
 
@@ -617,6 +677,13 @@ class RegionBuilderHelper {
     OpBuilder::InsertionGuard g(builder);
     builder.setInsertionPointToEnd(&block);
     auto loc = operand.getLoc();
+    if (isa<UnknownLoc>(loc)) {
+      if (operand.getDefiningOp())
+        loc = operand.getDefiningOp()->getLoc();
+      else if (operand.getParentBlock() &&
+               operand.getParentBlock()->getParentOp())
+        loc = operand.getParentBlock()->getParentOp()->getLoc();
+    }
     return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
   }
 
@@ -3664,9 +3731,15 @@ bool MatmulOp::hasUserDefinedMaps() {
 /// Implements the block region builder for the MatmulOp. This is called by
 /// 'fillStructuredOpRegion'.
 void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
-                             ArrayRef<NamedAttribute> attrs) {
-  assert(3 > 0 && block.getNumArguments() == 3 &&
-         "MatmulOp regionBuilder expects 3 (>=0) args");
+                             ArrayRef<NamedAttribute> attrs,
+                             function_ref<InFlightDiagnostic()> emitError) {
+  if (emitError && block.getNumArguments() != 3) {
+    emitError() << "MatmulOp regionBuilder expects 3 args, got "
+                << block.getNumArguments();
+    return;
+  }
+  assert(block.getNumArguments() == 3 &&
+         "MatmulOp regionBuilder expects 3 args");
   RegionBuilderHelper helper(b, block);
   SmallVector<Value> yields;
 
@@ -3683,9 +3756,13 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                                     block.getArgument(0));
   Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                     block.getArgument(1));
-  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
-  Value value4 =
-      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
+  if (!value3)
+    return;
+  Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
+                                      value3, emitError);
+  if (!value4)
+    return;
   yields.push_back(value4);
   helper.yieldOutputs(yields);
 }
@@ -3813,7 +3890,13 @@ unsigned ContractOp::getNumRegionArgs() { return 3; }
 
 /// Implement block region builder, which is called by 'fillStructuredOpRegion'.
 void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
-                               ArrayRef<NamedAttribute> attrs) {
+                               ArrayRef<NamedAttribute> attrs,
+                               function_ref<InFlightDiagnostic()> emitError) {
+  if (emitError && block.getNumArguments() != 3) {
+    emitError() << "ContractOp regionBuilder expects 3 args, got "
+                << block.getNumArguments();
+    return;
+  }
   assert(block.getNumArguments() == 3 &&
          "ContractOp regionBuilder expects 3 args");
   RegionBuilderHelper helper(b, block);
@@ -3833,10 +3916,14 @@ void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
       helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
   Value rhsAtOutType =
       helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
-  Value productAtOutType =
-      helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
+  Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
+                                                rhsAtOutType, emitError);
+  if (!productAtOutType)
+    return;
   Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
-                                      productAtOutType);
+                                      productAtOutType, emitError);
+  if (!result)
+    return;
   helper.yieldOutputs({result});
 }
 
@@ -4028,10 +4115,16 @@ bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
   return isValid;
 }
 
-void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
-                                  ArrayRef<NamedAttribute> attrs) {
+void BatchMatmulOp::regionBuilder(
+    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
+    function_ref<InFlightDiagnostic()> emitError) {
+  if (emitError && block.getNumArguments() != 3) {
+    emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
+                << block.getNumArguments();
+    return;
+  }
   assert(block.getNumArguments() == 3 &&
-         "BatchMatmulOp regionBuilder expects 3 (>=0) args");
+         "BatchMatmulOp regionBuilder expects 3 args");
   RegionBuilderHelper helper(b, block);
   SmallVector<Value> yields;
 
@@ -4303,8 +4396,9 @@ LogicalResult ElementwiseOp::verify() {
 
 /// Implements the block region builder for the ElementwiseOp. This is called by
 /// 'fillStructuredOpRegion'.
-void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
-                                  ArrayRef<NamedAttribute> attrs) {
+void ElementwiseOp::regionBuilder(
+    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
+    function_ref<InFlightDiagnostic()> emitError) {
   ElementwiseKind elemwiseKind;
   for (auto attr : attrs) {
     if (attr.getName() == b.getStringAttr("kind")) {
@@ -4318,6 +4412,13 @@ void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
   ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
   auto arityGroup = groupAndKind.arityGroup;
   auto kind = groupAndKind.kind;
+  if (emitError && block.getNumArguments() !=
+                       getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
+    emitError() << "Elementwise regionBuilder expects "
+                << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
+                << block.getNumArguments();
+    return;
+  }
   assert(block.getNumArguments() ==
              getArityGroupAsUInt(arityGroup) + 1 /*output*/
          && "Elementwise regionBuilder number of block args mismatch");
@@ -5501,10 +5602,16 @@ bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
   return isValid;
 }
 
-void BatchReduceMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
-                                        ArrayRef<NamedAttribute> attrs) {
+void BatchReduceMatmulOp::regionBuilder(
+    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
+    function_ref<InFlightDiagnostic()> emitError) {
+  if (emitError && block.getNumArguments() != 3) {
+    emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
+                << block.getNumArguments();
+    return;
+  }
   assert(block.getNumArguments() == 3 &&
-         "BatchReduceMatmulOp regionBuilder expects 3 (>=0) args");
+         "BatchReduceMatmulOp regionBuilder expects 3 args");
   RegionBuilderHelper helper(b, block);
   SmallVector<Value> yields;
 
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index ca40301f04fa1..775951fe3864a 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1868,9 +1868,49 @@ func.func @unpack_static_inner_tile_size_and_dynamic_output_shape(
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// linalg.reduce
+//===----------------------------------------------------------------------===//
+
+
 func.func @reduce_non_operation_name(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<f32> {
   // expected-error @below {{expected bare identifier or keyword}}
   %0 = linalg.reduce {@reduce_fusion_elementwise} ins(
     %arg0: tensor<4xf32>) outs(%arg1: tensor<f32>) dimensions = [0]
   return %0 : tensor<f32>
 }
+
+// -----
+
+
+//===----------------------------------------------------------------------===//
+// Named op error checking.
+//===----------------------------------------------------------------------===//
+
+module {
+  func.func @add_invalid_mixed_types(%in_f32: memref<3xf32>, %in_i32 : memref< 3xi32>, %out_f32: memref<3xf32>, %arg3: memref<3xf32>) {
+    // expected-error @below {{Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got 'f32' and 'i32'}}
+    linalg.add ins(%in_f32, %in_i32 : memref<3xf32>, memref< 3xi32>) outs(%out_f32 : memref<3xf32>)
+    return
+  }
+}
+
+// -----
+
+func.func @elemwise_unary_invalid_mixed_types(%arg0 : tensor<?xi32>) -> tensor<?xi32> {
+  // expected-error @below {{unsupported non numeric type}}
+  %0 = linalg.elemwise_unary ins(%arg0 : tensor<?xi32>) outs(%arg0 : tensor<?xi32>) -> tensor<?xi32>
+  return %0 : tensor<?xi32>
+}
+
+// -----
+
+func.func @matmul_invalid_mixed_types(%t: tensor<?xf16>, %f: vector<4xf16>)
+  -> (tensor<?xf16>, vector<4xf16>)
+{
+  // expected-warning @unknown {{could not cast operand of type 'f16' to 'vector<4xf16>'}}
+  // expected-error @below {{Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got 'vector<4xf16>' and 'f16'}}
+  %0 = linalg.matmul ins(%t, %t : tensor<?xf16>, tensor<?xf16>)
+                                outs(%f : vector<4xf16>) -> tensor<?xf16>
+  func.return %0, %f : tensor<?xf16>, vector<4xf16>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5fc7c33f4fb2b..1c961d272f192 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2737,12 +2737,14 @@ def TestLinalgConvOp :
     bool hasIndexSemantics() { return false; }
 
     static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
-                              mlir::ArrayRef<mlir::NamedAttribute> attrs) {
+                              mlir::ArrayRef<mlir::NamedAttribute> attrs,
+                              llvm::function_ref<mlir::InFlightDiagnostic()> emitError) {
       b.create<mlir::linalg::YieldOp>(block.getArguments().back());
     }
 
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
-                              mlir::ArrayRef<mlir::NamedAttribute>)>
+                              mlir::ArrayRef<mlir::NamedAttribute>,
+                              llvm::function_ref<mlir::InFlightDiagnostic()>)>
     getRegionBuilder() {
      return &regionBuilder;
     }
@@ -2798,12 +2800,14 @@ def TestLinalgFillOp :
     bool hasIndexSemantics() { return false; }
 
     static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
-                              mlir::ArrayRef<mlir::NamedAttribute> attrs) {
+                              mlir::ArrayRef<mlir::NamedAttribute> attrs,
+                              llvm::function_ref<mlir::InFlightDiagnostic()> emitError) {
       b.create<mlir::linalg::YieldOp>(block.getArguments().back());
     }
 
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
-                              mlir::ArrayRef<mlir::NamedAttribute>)>
+                              mlir::ArrayRef<mlir::NamedAttribute>,
+                              llvm::function_ref<mlir::InFlightDiagnostic()>)>
     getRegionBuilder() {
       return &regionBuilder;
     }
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index ab7b86125f693..00c70705cbb35 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -87,7 +87,8 @@ structured_op: !LinalgStructuredOpConfig
 #  ODS-NEXT:    }
 
 # IMPL-LABEL:  void Test1Op::regionBuilder(ImplicitLocOpBuilder &b,
-#  IMPL-NEXT:    Block &block, ArrayRef<NamedAttribute> attrs)
+#  IMPL-NEXT:    Block &block, ArrayRef<NamedAttribute> attrs,
+#  IMPL-NEXT:    function_ref<InFlightDiagnostic()> emitError)
 #       IMPL:  TypeFn castVal = TypeFn::cast_signed;
 #  IMPL-NEXT:  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
 #  IMPL-NEXT:                                return attr.getName() == "cast"; });
@@ -97,10 +98,10 @@ structured_op: !LinalgStructuredOpConfig
 #  IMPL-NEXT:  }
 
 #       IMPL:  Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64");
-#   IMPL-DAG:  Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]]);
+#   IMPL-DAG:  Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]], emitError);
 #   IMPL-DAG:  Value [[VAL2:[a-z0-9]+]] = helper.index(1);
-#   IMPL-DAG:  Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]]);
-#   IMPL-DAG:  Value [[VAL4:[a-z0-9]+]] = helper.buildBinaryFn(BinaryFn::add, [[VAL1]], [[VAL3]]);
+#   IMPL-DAG:  Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]], emitError);
+#   IMPL-DAG:  Value [[VAL4:[a-z0-9]+]] = helper.buildBinaryFn(BinaryFn::add, [[VAL1]], [[VAL3]], emitError);
 
 
 # @linalg_structured_op
@@ -186,7 +187,8 @@ structured_op: !LinalgStructuredOpConfig
 #       IMPL:  "incorrect element type for index attribute 'strides'"
 #       IMPL:  "incorrect shape for index attribute 'strides'"
 #       IMPL:  void Test2Op::regionBuilder(ImplicitLocOpBuilder &b,
-#  IMPL-NEXT:    Block &block, ArrayRef<NamedAttribute> attrs)
+#  IMPL-NEXT:    Block &block, ArrayRef<NamedAttribute> attrs,
+#  IMPL-NEXT:    function_ref<InFlightDiagnostic()> emitError)
 #  IMPL-NEXT:    assert(2 > 0 && block.getNumArguments() == 2 &&
 
 #       IMPL:   yields.push_back(block.getArgument(0));
@@ -315,13 +317,18 @@ structured_op: !LinalgStructuredOpConfig
 #  ODS-NEXT:    $_state.addAttribute("binary_fun", binary_fun)
 
 # IMPL-LABEL:  void Test4Op::regionBuilder(ImplicitLocOpBuilder &b,
-#  IMPL-NEXT:    Block &block, ArrayRef<NamedAttribute> attrs)
+#  IMPL-NEXT:    Block &block, ArrayRef<NamedAttribute> attrs,
+#  IMPL-NEXT:    function_ref<InFlightDiagnostic()> emitError)
 #       IMPL:  UnaryFn unary_funVal = UnaryFn::exp
 #       IMPL:  BinaryFn binary_funVal = BinaryFn::add
 
-#       IMPL:  Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0))
-#  IMPL-NEXT:  Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0))
-#  IMPL-NEXT:  yields.push_back([[VAL1]])
+#       IMPL:  Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0), emitError);
+#  IMPL-NEXT:  if (![[VAL0]])
+#  IMPL-NEXT:    return;
+#  IMPL:  Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0), emitError);
+#  IMPL-NEXT:  if (![[VAL1]])
+#  IMPL-NEXT:    return;
+#  IMPL:  yields.push_back([[VAL1]])
 
 # @linalg_structured_op
 # def test5(value=ScalarDef(T1), O=TensorDef(U, output=True)):
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 93a300e0b24a2..0a1693cff1d36 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -559,9 +559,10 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
       SmallVector<utils::IteratorType> getIteratorTypesArray();
       ArrayAttr getIndexingMaps();
       static void regionBuilder(ImplicitLocOpBuilder &b,
-                                Block &block, ArrayRef<NamedAttribute> attrs);
+                                Block &block, ArrayRef<NamedAttribute> attrs,
+                                function_ref<InFlightDiagnostic()> emitError);
       static std::function<void(ImplicitLocOpBuilder &,
-                                Block &, ArrayRef<NamedAttribute>)>
+                                Block &, ArrayRef<NamedAttribute>, function_ref<InFlightDiagnostic()> emitError)>
       getRegionBuilder() {{
         return regionBuilder;
       }
@@ -1010,7 +1011,8 @@ LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
     // {3}: Statements
     static const char structuredOpRegionBuilderFormat[] = R"FMT(
 void {0}::regionBuilder(ImplicitLocOpBuilder &b,
-                        Block &block, ArrayRef<NamedAttribute> attrs) {{
+                        Block &block, ArrayRef<NamedAttribute> attrs,
+                        function_ref<InFlightDiagnostic()> emitError) {{
   assert({1} > 0 && block.getNumArguments() == {1} &&
          "{0} regionBuilder expects {1} (>=0) args");
   RegionBuilderHelper helper(b, block);
@@ -1137,8 +1139,13 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
           // Call the function builder.
           std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
           stmts.push_back(llvm::formatv(
-              "Value {0} = helper.build{1}({2}, {3});", cppIdent, enumName,
-              funcType, interleaveToString(operandCppValues, ", ")));
+              R"mlir(
+              Value {0} = helper.build{1}({2}, {3}, emitError);
+              if (!{0})
+                return;
+              )mlir",
+              cppIdent, enumName, funcType,
+              interleaveToString(operandCppValues, ", ")));
           return cppIdent;
         }
         emitError(genContext.getLoc()) << "unknown ScalarExpression type";



More information about the Mlir-commits mailing list