[Mlir-commits] [mlir] c3f0efe - [mlir][Linalg] Fix crash in LinalgToStandard

Nicolas Vasilache llvmlistbot at llvm.org
Thu Jan 19 23:32:50 PST 2023


Author: Nicolas Vasilache
Date: 2023-01-19T23:29:19-08:00
New Revision: c3f0efe753e27105b519ae9283796d41fe574741

URL: https://github.com/llvm/llvm-project/commit/c3f0efe753e27105b519ae9283796d41fe574741
DIFF: https://github.com/llvm/llvm-project/commit/c3f0efe753e27105b519ae9283796d41fe574741.diff

LOG: [mlir][Linalg] Fix crash in LinalgToStandard

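Report a match failure via rewriter.notifyMatchFailure, instead of asserting or emitting a warning and returning a null attribute, when no library call can be generated for the op.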

Fixes #59986.

Added: 
    

Modified: 
    mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index 600c38174d33c..57419b7f1073a 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -52,14 +52,12 @@ static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
 
 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
 // If the library function does not exist, insert a declaration.
-static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
-                                                 PatternRewriter &rewriter) {
+static FailureOr<FlatSymbolRefAttr>
+getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter) {
   auto linalgOp = cast<LinalgOp>(op);
   auto fnName = linalgOp.getLibraryCallName();
-  if (fnName.empty()) {
-    op->emitWarning("No library call defined for: ") << *op;
-    return {};
-  }
+  if (fnName.empty())
+    return rewriter.notifyMatchFailure(op, "No library call defined for: ");
 
   // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
   FlatSymbolRefAttr fnNameAttr =
@@ -69,9 +67,12 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
     return fnNameAttr;
 
   SmallVector<Type, 4> inputTypes(extractOperandTypes(op));
-  assert(op->getNumResults() == 0 &&
-         "Library call for linalg operation can be generated only for ops that "
-         "have void return types");
+  if (op->getNumResults() != 0) {
+    return rewriter.notifyMatchFailure(
+        op,
+        "Library call for linalg operation can be generated only for ops that "
+        "have void return types");
+  }
   auto libFnType = rewriter.getFunctionType(inputTypes, {});
 
   OpBuilder::InsertionGuard guard(rewriter);
@@ -110,13 +111,13 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
 LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
     LinalgOp op, PatternRewriter &rewriter) const {
   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
-  if (!libraryCallName)
+  if (failed(libraryCallName))
     return failure();
 
   // TODO: Add support for more complex library call signatures that include
   // indices or captured values.
   rewriter.replaceOpWithNewOp<func::CallOp>(
-      op, libraryCallName.getValue(), TypeRange(),
+      op, libraryCallName->getValue(), TypeRange(),
       createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(),
                                             op->getOperands()));
   return success();


        


More information about the Mlir-commits mailing list