[Mlir-commits] [mlir] 5a51a44 - [mlir][llvm] Improve error message when translating `llvm.call_intrinsic`

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 28 00:17:28 PDT 2023


Author: Mogball
Date: 2023-07-28T07:17:21Z
New Revision: 5a51a44f82497b089337cfd6c3d86e3d7e3e0041

URL: https://github.com/llvm/llvm-project/commit/5a51a44f82497b089337cfd6c3d86e3d7e3e0041
DIFF: https://github.com/llvm/llvm-project/commit/5a51a44f82497b089337cfd6c3d86e3d7e3e0041.diff

LOG: [mlir][llvm] Improve error message when translating `llvm.call_intrinsic`

This is more user-friendly than an opaque crash.

Reviewed By: lattner

Differential Revision: https://reviews.llvm.org/D156475

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
    mlir/test/Dialect/LLVMIR/call-intrin.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index c144afbd5de69b..990af1d3d61fe0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -1080,31 +1080,6 @@ def LLVM_vector_extract
   }];
 }
 
-//===--------------------------------------------------------------------===//
-// CallIntrinsicOp
-//===--------------------------------------------------------------------===//
-
-def LLVM_CallIntrinsicOp
-    : LLVM_Op<"call_intrinsic",
-              [DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
-  let summary = "Call to an LLVM intrinsic function.";
-  let description = [{
-    Call the specified llvm intrinsic. If the intrinsic is overloaded, use
-    the MLIR function type of this op to determine which intrinsic to call.
-    }];
-  let arguments = (ins StrAttr:$intrin, Variadic<LLVM_Type>:$args,
-                       DefaultValuedAttr<LLVM_FastmathFlagsAttr,
-                                         "{}">:$fastmathFlags);
-  let results = (outs Variadic<LLVM_Type>:$results);
-  let llvmBuilder = [{
-    return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
-  }];
-  let assemblyFormat = [{
-    $intrin `(` $args `)` `:` functional-type($args, $results)
-      custom<LLVMOpAttrs>(attr-dict)
-  }];
-}
-
 //
 // LLVM Vector Predication operations.
 //

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index f5c76ff448506c..6fb422fea5b6d8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1759,4 +1759,30 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", []> {
   }];
 }
 
+//===--------------------------------------------------------------------===//
+// CallIntrinsicOp
+//===--------------------------------------------------------------------===//
+
+def LLVM_CallIntrinsicOp
+    : LLVM_Op<"call_intrinsic",
+              [DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
+  let summary = "Call to an LLVM intrinsic function.";
+  let description = [{
+    Call the specified LLVM intrinsic; its name must start with `llvm.`. If the
+    intrinsic is overloaded, the op's MLIR function type selects the overload.
+  }];
+  let arguments = (ins StrAttr:$intrin, Variadic<LLVM_Type>:$args,
+                       DefaultValuedAttr<LLVM_FastmathFlagsAttr,
+                                         "{}">:$fastmathFlags);
+  let results = (outs Optional<LLVM_Type>:$results);
+  let llvmBuilder = [{
+    return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
+  }];
+  let assemblyFormat = [{
+    $intrin `(` $args `)` `:` functional-type($args, $results) attr-dict
+  }];
+
+  let hasVerifier = 1;
+}
+
 #endif // LLVMIR_OPS

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 8e5b137da7a32d..2892312b5406d8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -10,6 +10,7 @@
 // MLIR, and the LLVM IR dialect.  It also registers the dialect.
 //
 //===----------------------------------------------------------------------===//
+
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "LLVMInlining.h"
 #include "TypeDetail.h"
@@ -2785,6 +2786,59 @@ OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) {
   return IntegerAttr::get(getType(), lhs.getValue() | rhs.getValue());
 }
 
+//===----------------------------------------------------------------------===//
+// Utilities for LLVM::MetadataOp
+//===----------------------------------------------------------------------===//
+
+void MetadataOp::build(OpBuilder &builder, OperationState &result,
+                       StringRef symName, bool createBodyBlock,
+                       ArrayRef<NamedAttribute> attributes) {
+  result.addAttribute(getSymNameAttrName(result.name),
+                      builder.getStringAttr(symName));
+  result.attributes.append(attributes.begin(), attributes.end());
+  Region *body = result.addRegion();
+  if (createBodyBlock)
+    body->emplaceBlock();
+}
+
+ParseResult MetadataOp::parse(OpAsmParser &parser, OperationState &result) {
+  StringAttr symName;
+  if (parser.parseSymbolName(symName, getSymNameAttrName(result.name),
+                             result.attributes) ||
+      parser.parseOptionalAttrDictWithKeyword(result.attributes))
+    return failure();
+
+  Region *bodyRegion = result.addRegion();
+  if (parser.parseRegion(*bodyRegion))
+    return failure();
+
+  // parseRegion() can succeed while leaving the region with no blocks; add
+  // the body block explicitly so the parsed op always has one.
+  if (bodyRegion->empty())
+    bodyRegion->emplaceBlock();
+
+  return success();
+}
+
+void MetadataOp::print(OpAsmPrinter &printer) {
+  printer << ' ';
+  printer.printSymbolName(getSymName());
+  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
+                                           {getSymNameAttrName().getValue()});
+  printer << ' ';
+  printer.printRegion(getBody());
+}
+
+//===----------------------------------------------------------------------===//
+// CallIntrinsicOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult CallIntrinsicOp::verify() {
+  if (!getIntrin().startswith("llvm."))
+    return emitOpError() << "intrinsic name must start with 'llvm.'";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // OpAsmDialectInterface
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 8f7c5d8b799e27..11e841b2e5dc06 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -58,11 +58,19 @@ static SmallVector<unsigned> extractPosition(ArrayRef<int64_t> indices) {
   return position;
 }
 
+/// Render an LLVM type via llvm::Type::print for use in diagnostic messages.
+static std::string diagStr(const llvm::Type *type) {
+  std::string str;
+  llvm::raw_string_ostream os(str);
+  type->print(os);
+  return os.str();
+}
+
 /// Get the declaration of an overloaded llvm intrinsic. First we get the
 /// overloaded argument types and/or result type from the CallIntrinsicOp, and
 /// then use those to get the correct declaration of the overloaded intrinsic.
 static FailureOr<llvm::Function *>
-getOverloadedDeclaration(CallIntrinsicOp &op, llvm::Intrinsic::ID id,
+getOverloadedDeclaration(CallIntrinsicOp op, llvm::Intrinsic::ID id,
                          llvm::Module *module,
                          LLVM::ModuleTranslation &moduleTranslation) {
   SmallVector<llvm::Type *, 8> allArgTys;
@@ -86,7 +94,9 @@ getOverloadedDeclaration(CallIntrinsicOp &op, llvm::Intrinsic::ID id,
   if (llvm::Intrinsic::matchIntrinsicSignature(ft, tableRef,
                                                overloadedArgTys) !=
       llvm::Intrinsic::MatchIntrinsicTypesResult::MatchIntrinsicTypes_Match) {
-    return op.emitOpError("intrinsic type is not a match");
+    return mlir::emitError(op.getLoc(), "call intrinsic signature ")
+           << diagStr(ft) << " to overloaded intrinsic " << op.getIntrinAttr()
+           << " does not match any of the overloads";
   }
 
   ArrayRef<llvm::Type *> overloadedArgTysRef = overloadedArgTys;
@@ -101,8 +111,8 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
   llvm::Intrinsic::ID id =
       llvm::Function::lookupIntrinsicID(op.getIntrinAttr());
   if (!id)
-    return op.emitOpError()
-           << "couldn't find intrinsic: " << op.getIntrinAttr();
+    return mlir::emitError(op.getLoc(), "could not find LLVM intrinsic: ")
+           << op.getIntrinAttr();
 
   llvm::Function *fn = nullptr;
   if (llvm::Intrinsic::isOverloaded(id)) {
@@ -114,6 +124,44 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
   } else {
     fn = llvm::Intrinsic::getDeclaration(module, id, {});
   }
+
+  // Check the call's result type (void when the op has no results).
+  const llvm::Type *intrinType =
+      op.getNumResults() == 0
+          ? llvm::Type::getVoidTy(module->getContext())
+          : moduleTranslation.convertType(op.getResultTypes().front());
+  if (intrinType != fn->getReturnType()) {
+    return mlir::emitError(op.getLoc(), "intrinsic call returns ")
+           << diagStr(intrinType) << " but " << op.getIntrinAttr()
+           << " actually returns " << diagStr(fn->getReturnType());
+  }
+
+  // Check the operand count: it must match exactly for non-variadic
+  // intrinsics and be at least the fixed parameter count for variadic ones.
+  if (!fn->getFunctionType()->isVarArg() &&
+      op.getNumOperands() != fn->arg_size()) {
+    return mlir::emitError(op.getLoc(), "intrinsic call has ")
+           << op.getNumOperands() << " operands but " << op.getIntrinAttr()
+           << " expects " << fn->arg_size();
+  }
+  if (fn->getFunctionType()->isVarArg() &&
+      op.getNumOperands() < fn->arg_size()) {
+    return mlir::emitError(op.getLoc(), "intrinsic call has ")
+           << op.getNumOperands() << " operands but variadic "
+           << op.getIntrinAttr() << " expects at least " << fn->arg_size();
+  }
+  // Type-check the fixed parameters; extra variadic operands are not checked.
+  for (unsigned i = 0, e = fn->arg_size(); i != e; ++i) {
+    const llvm::Type *expected = fn->getArg(i)->getType();
+    const llvm::Type *actual =
+        moduleTranslation.convertType(op.getOperandTypes()[i]);
+    if (actual != expected) {
+      return mlir::emitError(op.getLoc(), "intrinsic call operand #")
+             << i << " has type " << diagStr(actual) << " but "
+             << op.getIntrinAttr() << " expects " << diagStr(expected);
+    }
+  }
+
   FastmathFlagsInterface itf = op;
   builder.setFastMathFlags(getFastmathFlags(itf));
 

diff  --git a/mlir/test/Dialect/LLVMIR/call-intrin.mlir b/mlir/test/Dialect/LLVMIR/call-intrin.mlir
index 1b8cd54ab26176..24aa38fca4a658 100644
--- a/mlir/test/Dialect/LLVMIR/call-intrin.mlir
+++ b/mlir/test/Dialect/LLVMIR/call-intrin.mlir
@@ -1,82 +1,107 @@
 // RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s
 
-// CHECK: ; ModuleID = 'LLVMDialectModule'
-// CHECK: source_filename = "LLVMDialectModule"
-// CHECK: declare ptr @malloc(i64)
-// CHECK: declare void @free(ptr)
 // CHECK: define <4 x float> @round_sse41() {
-// CHECK:  %1 = call reassoc <4 x float> @llvm.x86.sse41.round.ss(<4 x float> <float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000>, <4 x float> <float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000>, i32 1)
+// CHECK:  %1 = call reassoc <4 x float> @llvm.x86.sse41.round.ss(<4 x float> {{.*}}, <4 x float> {{.*}}, i32 1)
 // CHECK:  ret <4 x float> %1
 // CHECK: }
 llvm.func @round_sse41() -> vector<4xf32> {
-    %0 = llvm.mlir.constant(1 : i32) : i32
-    %1 = llvm.mlir.constant(dense<0.2> : vector<4xf32>) : vector<4xf32>
-    %res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32> {fastmathFlags = #llvm.fastmath<reassoc>}
-    llvm.return %res: vector<4xf32>
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.mlir.constant(dense<0.2> : vector<4xf32>) : vector<4xf32>
+  %res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32> {fastmathFlags = #llvm.fastmath<reassoc>}
+  llvm.return %res: vector<4xf32>
 }
 
 // -----
 
-// CHECK: ; ModuleID = 'LLVMDialectModule'
-// CHECK: source_filename = "LLVMDialectModule"
-
-// CHECK: declare ptr @malloc(i64)
-
-// CHECK: declare void @free(ptr)
-
 // CHECK: define float @round_overloaded() {
 // CHECK:   %1 = call float @llvm.round.f32(float 1.000000e+00)
 // CHECK:   ret float %1
 // CHECK: }
 llvm.func @round_overloaded() -> f32 {
-    %0 = llvm.mlir.constant(1.0 : f32) : f32
-    %res = llvm.call_intrinsic "llvm.round"(%0) : (f32) -> f32 {}
-    llvm.return %res: f32
+  %0 = llvm.mlir.constant(1.0 : f32) : f32
+  %res = llvm.call_intrinsic "llvm.round"(%0) : (f32) -> f32 {}
+  llvm.return %res: f32
 }
 
 // -----
 
-// CHECK: ; ModuleID = 'LLVMDialectModule'
-// CHECK: source_filename = "LLVMDialectModule"
-// CHECK: declare ptr @malloc(i64)
-// CHECK: declare void @free(ptr)
 // CHECK: define void @lifetime_start() {
 // CHECK:   %1 = alloca float, i8 1, align 4
 // CHECK:   call void @llvm.lifetime.start.p0(i64 4, ptr %1)
 // CHECK:   ret void
 // CHECK: }
 llvm.func @lifetime_start() {
-    %0 = llvm.mlir.constant(4 : i64) : i64
-    %1 = llvm.mlir.constant(1 : i8) : i8
-    %2 = llvm.alloca %1 x f32 : (i8) -> !llvm.ptr
-    llvm.call_intrinsic "llvm.lifetime.start"(%0, %2) : (i64, !llvm.ptr) -> () {}
-    llvm.return
+  %0 = llvm.mlir.constant(4 : i64) : i64
+  %1 = llvm.mlir.constant(1 : i8) : i8
+  %2 = llvm.alloca %1 x f32 : (i8) -> !llvm.ptr
+  llvm.call_intrinsic "llvm.lifetime.start"(%0, %2) : (i64, !llvm.ptr) -> () {}
+  llvm.return
 }
 
 // -----
 
+// CHECK-LABEL: define void @variadic()
 llvm.func @variadic() {
-    %0 = llvm.mlir.constant(1 : i8) : i8
-    %1 = llvm.alloca %0 x f32 : (i8) -> !llvm.ptr
-    llvm.call_intrinsic "llvm.localescape"(%1, %1) : (!llvm.ptr, !llvm.ptr) -> ()
-    llvm.return
+  %0 = llvm.mlir.constant(1 : i8) : i8
+  %1 = llvm.alloca %0 x f32 : (i8) -> !llvm.ptr
+  // CHECK: call void (...) @llvm.localescape(ptr %1, ptr %1)
+  llvm.call_intrinsic "llvm.localescape"(%1, %1) : (!llvm.ptr, !llvm.ptr) -> ()
+  llvm.return
 }
 
 // -----
 
 llvm.func @no_intrinsic() {
-    // expected-error@below {{'llvm.call_intrinsic' op couldn't find intrinsic: "llvm.does_not_exist"}}
-    // expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
-    llvm.call_intrinsic "llvm.does_not_exist"() : () -> ()
-    llvm.return
+  // expected-error@below {{could not find LLVM intrinsic: "llvm.does_not_exist"}}
+  // expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
+  llvm.call_intrinsic "llvm.does_not_exist"() : () -> ()
+  llvm.return
 }
 
 // -----
 
 llvm.func @bad_types() {
-    %0 = llvm.mlir.constant(1 : i8) : i8
-    // expected-error@below {{'llvm.call_intrinsic' op intrinsic type is not a match}}
-    // expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
-    llvm.call_intrinsic "llvm.round"(%0) : (i8) -> i8 {}
-    llvm.return
+  %0 = llvm.mlir.constant(1 : i8) : i8
+  // expected-error@below {{call intrinsic signature i8 (i8) to overloaded intrinsic "llvm.round" does not match any of the overloads}}
+  // expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
+  llvm.call_intrinsic "llvm.round"(%0) : (i8) -> i8 {}
+  llvm.return
+}
+
+// -----
+
+llvm.func @bad_result() {
+  // expected-error @below {{intrinsic call returns void but "llvm.x86.sse41.round.ss" actually returns <4 x float>}}
+  // expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
+  llvm.call_intrinsic "llvm.x86.sse41.round.ss"() : () -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @bad_result() {
+  // expected-error @below {{intrinsic call returns <8 x float> but "llvm.x86.sse41.round.ss" actually returns <4 x float>}}
+  // expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
+  llvm.call_intrinsic "llvm.x86.sse41.round.ss"() : () -> (vector<8xf32>)
+  llvm.return
+}
+
+// -----
+
+llvm.func @bad_args() {
+  // expected-error @below {{intrinsic call has 0 operands but "llvm.x86.sse41.round.ss" expects 3}}
+  // expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
+  llvm.call_intrinsic "llvm.x86.sse41.round.ss"() : () -> (vector<4xf32>)
+  llvm.return
+}
+
+// -----
+
+llvm.func @bad_args() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.mlir.constant(dense<0.2> : vector<4xf32>) : vector<4xf32>
+  // expected-error @below {{intrinsic call operand #2 has type i64 but "llvm.x86.sse41.round.ss" expects i32}}
+  // expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
+  %res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) : (vector<4xf32>, vector<4xf32>, i64) -> vector<4xf32> {fastmathFlags = #llvm.fastmath<reassoc>}
+  llvm.return
 }


        


More information about the Mlir-commits mailing list