[Mlir-commits] [mlir] cc49a74 - [mlir][llvm] Use TableGen to import compare ops from LLVM IR.

Tobias Gysi llvmlistbot at llvm.org
Thu Oct 13 05:34:07 PDT 2022


Author: Tobias Gysi
Date: 2022-10-13T15:31:04+03:00
New Revision: cc49a74a7bc6790d3502d1c2666712c1ef0c211b

URL: https://github.com/llvm/llvm-project/commit/cc49a74a7bc6790d3502d1c2666712c1ef0c211b
DIFF: https://github.com/llvm/llvm-project/commit/cc49a74a7bc6790d3502d1c2666712c1ef0c211b.diff

LOG: [mlir][llvm] Use TableGen to import compare ops from LLVM IR.

The revision imports compare operations using TableGen generated
builders, instead of using the special handlers defined by the Importer.
It therefore adds a new llvmArgIndices field that allows specifying
a mapping between MLIR argument and LLVM IR operand indices if they do
not match. Additionally, the FCmp op is extended with an additional
builder and all compare operations are extended with verification
traits to ensure the operand types match. These extensions simplify
the logic of the newly introduced builders and are in line with the
compare operations defined by the arithmetic dialect.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D135855

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
    mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
    mlir/test/Target/LLVMIR/Import/basic.ll
    mlir/test/Target/LLVMIR/Import/instructions.ll
    mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 5275739f750ea..2ed9c571f4c08 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -228,9 +228,16 @@ class LLVM_OpBase<Dialect dialect, string mnemonic, list<Trait> traits = []> :
   //   - $_builder - substituted with the MLIR builder;
   //   - $_qualCppClassName - substitiuted with the MLIR operation class name.
   // Additionally, `$$` can be used to produce the dollar character.
-  // NOTE: The $name variable resolution assumes the MLIR and LLVM argument
-  // orders match and there are no optional or variadic arguments.
+  // FIXME: The $name variable resolution does not support variadic arguments.
   string mlirBuilder = "";
+
+  // An array that specifies a mapping from MLIR argument indices to LLVM IR
+  // operand indices. The mapping is necessary since argument and operand
+  // indices do not always match. If not defined, the array is set to the
+  // identity permutation. An operation may define any custom index permutation
+  // and set a specific argument index to -1 if it does not map to an LLVM IR
+  // operand.
+  list<int> llvmArgIndices = [];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index acc56a932a08b..5f798bc123a8d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -284,19 +284,39 @@ def ICmpPredicate : I64EnumAttr<
   let cppNamespace = "::mlir::LLVM";
 }
 
+// Base class for compare operations. A compare operation takes two operands
+// of the same type and returns a boolean result. If the operands are
+// vectors, then the result has to be a boolean vector of the same shape.
+class LLVM_ArithmeticCmpOp<string mnemonic, list<Trait> traits = []> :
+    LLVM_Op<mnemonic, traits # [SameTypeOperands, TypesMatchWith<
+    "result type has i1 element type and same shape as operands",
+    "lhs", "res", "::getI1SameShape($_self)">]> {
+  let results = (outs LLVM_ScalarOrVectorOf<I1>:$res);
+}
+
 // Other integer operations.
-def LLVM_ICmpOp : LLVM_Op<"icmp", [Pure]> {
+def LLVM_ICmpOp : LLVM_ArithmeticCmpOp<"icmp", [Pure]> {
   let arguments = (ins ICmpPredicate:$predicate,
-                   AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$lhs,
-                   AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$rhs);
-  let results = (outs LLVM_ScalarOrVectorOf<I1>:$res);
-  let llvmBuilder = [{
-    $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
-  }];
+                   AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>,
+                              LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$lhs,
+                   AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>,
+                              LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$rhs);
   let builders = [
     OpBuilder<(ins "ICmpPredicate":$predicate, "Value":$lhs, "Value":$rhs)>
   ];
   let hasCustomAssemblyFormat = 1;
+  string llvmInstName = "ICmp";
+  string llvmBuilder = [{
+    $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
+  }];
+  string mlirBuilder = [{
+    auto *iCmpInst = cast<llvm::ICmpInst>(inst);
+    $res = $_builder.create<$_qualCppClassName>(
+      $_location, getICmpPredicate(iCmpInst->getPredicate()), $lhs, $rhs);
+  }];
+  // Set the $predicate index to -1 to indicate there is no matching operand
+  // and decrement the following indices.
+  list<int> llvmArgIndices = [-1, 0, 1];
 }
 
 // Predicate for float comparisons
@@ -329,17 +349,29 @@ def FCmpPredicate : I64EnumAttr<
 }
 
 // Other floating-point operations.
-def LLVM_FCmpOp : LLVM_Op<"fcmp", [
+def LLVM_FCmpOp : LLVM_ArithmeticCmpOp<"fcmp", [
     Pure, DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
   let arguments = (ins FCmpPredicate:$predicate,
                    LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$lhs,
                    LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$rhs,
                    DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
-  let results = (outs LLVM_ScalarOrVectorOf<I1>:$res);
-  let llvmBuilder = [{
+  let builders = [
+    OpBuilder<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs)>
+  ];
+  let hasCustomAssemblyFormat = 1;
+  string llvmInstName = "FCmp";
+  string llvmBuilder = [{
     $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
   }];
-  let hasCustomAssemblyFormat = 1;
+  // FIXME: Import fastmath flags.
+  string mlirBuilder = [{
+    auto *fCmpInst = cast<llvm::FCmpInst>(inst);
+    $res = $_builder.create<$_qualCppClassName>(
+      $_location, getFCmpPredicate(fCmpInst->getPredicate()), $lhs, $rhs);
+  }];
+  // Set the $predicate index to -1 to indicate there is no matching operand
+  // and decrement the following indices.
+  list<int> llvmArgIndices = [-1, 0, 1, 2];
 }
 
 // Floating point binary operations.

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 1f307c2aa7372..f046b736f0218 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -521,6 +521,10 @@ bool isScalableVectorType(Type vectorType);
 Type getVectorType(Type elementType, unsigned numElements,
                    bool isScalable = false);
 
+/// Creates an LLVM dialect-compatible vector type with the given element type
+/// and length.
+Type getVectorType(Type elementType, const llvm::ElementCount &numElements);
+
 /// Creates an LLVM dialect-compatible type with the given element type and
 /// length.
 Type getFixedVectorType(Type elementType, unsigned numElements);

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 205866230913c..365febf32ef2a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -88,28 +88,27 @@ static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
   return success();
 }
 
+/// Returns a boolean type that has the same shape as `type`. It supports both
+/// fixed size vectors as well as scalable vectors.
+static Type getI1SameShape(Type type) {
+  Type i1Type = IntegerType::get(type.getContext(), 1);
+  if (LLVM::isCompatibleVectorType(type))
+    return LLVM::getVectorType(i1Type, LLVM::getVectorNumElements(type));
+  return i1Type;
+}
+
 //===----------------------------------------------------------------------===//
 // Printing, parsing and builder for LLVM::CmpOp.
 //===----------------------------------------------------------------------===//
 
 void ICmpOp::build(OpBuilder &builder, OperationState &result,
                    ICmpPredicate predicate, Value lhs, Value rhs) {
-  auto boolType = IntegerType::get(lhs.getType().getContext(), 1);
-  if (LLVM::isCompatibleVectorType(lhs.getType()) ||
-      LLVM::isCompatibleVectorType(rhs.getType())) {
-    int64_t numLHSElements = 1, numRHSElements = 1;
-    if (LLVM::isCompatibleVectorType(lhs.getType()))
-      numLHSElements =
-          LLVM::getVectorNumElements(lhs.getType()).getFixedValue();
-    if (LLVM::isCompatibleVectorType(rhs.getType()))
-      numRHSElements =
-          LLVM::getVectorNumElements(rhs.getType()).getFixedValue();
-    build(builder, result,
-          VectorType::get({std::max(numLHSElements, numRHSElements)}, boolType),
-          predicate, lhs, rhs);
-  } else {
-    build(builder, result, boolType, predicate, lhs, rhs);
-  }
+  build(builder, result, getI1SameShape(lhs.getType()), predicate, lhs, rhs);
+}
+
+void FCmpOp::build(OpBuilder &builder, OperationState &result,
+                   FCmpPredicate predicate, Value lhs, Value rhs) {
+  build(builder, result, getI1SameShape(lhs.getType()), predicate, lhs, rhs);
 }
 
 void ICmpOp::print(OpAsmPrinter &p) {
@@ -132,8 +131,6 @@ void FCmpOp::print(OpAsmPrinter &p) {
 //                 attribute-dict? `:` type
 template <typename CmpPredicateType>
 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
-  Builder &builder = parser.getBuilder();
-
   StringAttr predicateAttr;
   OpAsmParser::UnresolvedOperand lhs, rhs;
   Type type;
@@ -173,23 +170,10 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
 
   // The result type is either i1 or a vector type <? x i1> if the inputs are
   // vectors.
-  Type resultType = IntegerType::get(builder.getContext(), 1);
   if (!isCompatibleType(type))
     return parser.emitError(trailingTypeLoc,
                             "expected LLVM dialect-compatible type");
-  if (LLVM::isCompatibleVectorType(type)) {
-    if (LLVM::isScalableVectorType(type)) {
-      resultType = LLVM::getVectorType(
-          resultType, LLVM::getVectorNumElements(type).getKnownMinValue(),
-          /*isScalable=*/true);
-    } else {
-      resultType = LLVM::getVectorType(
-          resultType, LLVM::getVectorNumElements(type).getFixedValue(),
-          /*isScalable=*/false);
-    }
-  }
-
-  result.addTypes({resultType});
+  result.addTypes(getI1SameShape(type));
   return success();
 }
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 80dcfa84f6937..279ea52aebada 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -934,6 +934,15 @@ Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
   return VectorType::get(numElements, elementType, (unsigned)isScalable);
 }
 
+Type mlir::LLVM::getVectorType(Type elementType,
+                               const llvm::ElementCount &numElements) {
+  if (numElements.isScalable())
+    return getVectorType(elementType, numElements.getKnownMinValue(),
+                         /*isScalable=*/true);
+  return getVectorType(elementType, numElements.getFixedValue(),
+                       /*isScalable=*/false);
+}
+
 Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
   bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
   bool useBuiltIn = VectorType::isValidElementType(elementType);

diff  --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index bbffd9681f8c4..523aebf43baac 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -804,36 +804,6 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
 
   // Convert all special instructions that do not provide an MLIR builder.
   Location loc = translateLoc(inst->getDebugLoc());
-  if (inst->getOpcode() == llvm::Instruction::ICmp) {
-    Value lhs = processValue(inst->getOperand(0));
-    Value rhs = processValue(inst->getOperand(1));
-    Value res = b.create<ICmpOp>(
-        loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs,
-        rhs);
-    mapValue(inst, res);
-    return success();
-  }
-  if (inst->getOpcode() == llvm::Instruction::FCmp) {
-    Value lhs = processValue(inst->getOperand(0));
-    Value rhs = processValue(inst->getOperand(1));
-
-    if (lhs.getType() != rhs.getType())
-      return failure();
-
-    Type boolType = b.getI1Type();
-    Type resType = boolType;
-    if (LLVM::isCompatibleVectorType(lhs.getType())) {
-      unsigned numElements =
-          LLVM::getVectorNumElements(lhs.getType()).getFixedValue();
-      resType = VectorType::get({numElements}, boolType);
-    }
-
-    Value res = b.create<FCmpOp>(
-        loc, resType,
-        getFCmpPredicate(cast<llvm::FCmpInst>(inst)->getPredicate()), lhs, rhs);
-    mapValue(inst, res);
-    return success();
-  }
   if (inst->getOpcode() == llvm::Instruction::Br) {
     auto *brInst = cast<llvm::BranchInst>(inst);
     OperationState state(loc,

diff  --git a/mlir/test/Target/LLVMIR/Import/basic.ll b/mlir/test/Target/LLVMIR/Import/basic.ll
index bb41a5b23e166..961c5e876f3f8 100644
--- a/mlir/test/Target/LLVMIR/Import/basic.ll
+++ b/mlir/test/Target/LLVMIR/Import/basic.ll
@@ -227,28 +227,6 @@ define i32* @f3() {
   ret i32* bitcast (double* @g2 to i32*)
 }
 
-; CHECK-LABEL: llvm.func @f5
-define void @f5(i32 %d) {
-; FIXME: icmp should return i1.
-; CHECK: = llvm.icmp "eq"
-  %1 = icmp eq i32 %d, 2
-; CHECK: = llvm.icmp "slt"
-  %2 = icmp slt i32 %d, 2
-; CHECK: = llvm.icmp "sle"
-  %3 = icmp sle i32 %d, 2
-; CHECK: = llvm.icmp "sgt"
-  %4 = icmp sgt i32 %d, 2
-; CHECK: = llvm.icmp "sge"
-  %5 = icmp sge i32 %d, 2
-; CHECK: = llvm.icmp "ult"
-  %6 = icmp ult i32 %d, 2
-; CHECK: = llvm.icmp "ule"
-  %7 = icmp ule i32 %d, 2
-; CHECK: = llvm.icmp "ugt"
-  %8 = icmp ugt i32 %d, 2
-  ret void
-}
-
 ; CHECK-LABEL: llvm.func @f6(%arg0: !llvm.ptr<func<void (i16)>>)
 define void @f6(void (i16) *%fn) {
 ; CHECK: %[[c:[0-9]+]] = llvm.mlir.constant(0 : i16) : i16
@@ -257,43 +235,6 @@ define void @f6(void (i16) *%fn) {
   ret void
 }
 
-; CHECK-LABEL: llvm.func @FPComparison(%arg0: f32, %arg1: f32)
-define void @FPComparison(float %a, float %b) {
-  ; CHECK: llvm.fcmp "_false" %arg0, %arg1
-  %1 = fcmp false float %a, %b
-  ; CHECK: llvm.fcmp "oeq" %arg0, %arg1
-  %2 = fcmp oeq float %a, %b
-  ; CHECK: llvm.fcmp "ogt" %arg0, %arg1
-  %3 = fcmp ogt float %a, %b
-  ; CHECK: llvm.fcmp "oge" %arg0, %arg1
-  %4 = fcmp oge float %a, %b
-  ; CHECK: llvm.fcmp "olt" %arg0, %arg1
-  %5 = fcmp olt float %a, %b
-  ; CHECK: llvm.fcmp "ole" %arg0, %arg1
-  %6 = fcmp ole float %a, %b
-  ; CHECK: llvm.fcmp "one" %arg0, %arg1
-  %7 = fcmp one float %a, %b
-  ; CHECK: llvm.fcmp "ord" %arg0, %arg1
-  %8 = fcmp ord float %a, %b
-  ; CHECK: llvm.fcmp "ueq" %arg0, %arg1
-  %9 = fcmp ueq float %a, %b
-  ; CHECK: llvm.fcmp "ugt" %arg0, %arg1
-  %10 = fcmp ugt float %a, %b
-  ; CHECK: llvm.fcmp "uge" %arg0, %arg1
-  %11 = fcmp uge float %a, %b
-  ; CHECK: llvm.fcmp "ult" %arg0, %arg1
-  %12 = fcmp ult float %a, %b
-  ; CHECK: llvm.fcmp "ule" %arg0, %arg1
-  %13 = fcmp ule float %a, %b
-  ; CHECK: llvm.fcmp "une" %arg0, %arg1
-  %14 = fcmp une float %a, %b
-  ; CHECK: llvm.fcmp "uno" %arg0, %arg1
-  %15 = fcmp uno float %a, %b
-  ; CHECK: llvm.fcmp "_true" %arg0, %arg1
-  %16 = fcmp true float %a, %b
-  ret void
-}
-
 ; Testing rest of the floating point constant kinds.
 ; CHECK-LABEL: llvm.func @FPConstant(%arg0: f16, %arg1: bf16, %arg2: f128, %arg3: f80)
 define void @FPConstant(half %a, bfloat %b, fp128 %c, x86_fp80 %d) {

diff  --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index 1dad3f86a81f6..d42840dda521a 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -9,38 +9,66 @@ define void @integer_arith(i32 %arg1, i32 %arg2, i64 %arg3, i64 %arg4) {
   ; CHECK-DAG:  %[[C1:[0-9]+]] = llvm.mlir.constant(-7 : i32) : i32
   ; CHECK-DAG:  %[[C2:[0-9]+]] = llvm.mlir.constant(42 : i32) : i32
   ; CHECK:  llvm.add %[[ARG1]], %[[C1]] : i32
-  ; CHECK:  llvm.add %[[C2]], %[[ARG2]] : i32
-  ; CHECK:  llvm.sub %[[ARG3]], %[[ARG4]] : i64
-  ; CHECK:  llvm.mul %[[ARG1]], %[[ARG2]] : i32
-  ; CHECK:  llvm.udiv %[[ARG3]], %[[ARG4]] : i64
-  ; CHECK:  llvm.sdiv %[[ARG1]], %[[ARG2]] : i32
-  ; CHECK:  llvm.urem %[[ARG3]], %[[ARG4]] : i64
-  ; CHECK:  llvm.srem %[[ARG1]], %[[ARG2]] : i32
-  ; CHECK:  llvm.shl %[[ARG3]], %[[ARG4]] : i64
-  ; CHECK:  llvm.lshr %[[ARG1]], %[[ARG2]] : i32
-  ; CHECK:  llvm.ashr %[[ARG3]], %[[ARG4]] : i64
-  ; CHECK:  llvm.and %[[ARG1]], %[[ARG2]] : i32
-  ; CHECK:  llvm.or %[[ARG3]], %[[ARG4]] : i64
-  ; CHECK:  llvm.xor %[[ARG1]], %[[ARG2]] : i32
   %1 = add i32 %arg1, -7
+  ; CHECK:  llvm.add %[[C2]], %[[ARG2]] : i32
   %2 = add i32 42, %arg2
+  ; CHECK:  llvm.sub %[[ARG3]], %[[ARG4]] : i64
   %3 = sub i64 %arg3, %arg4
+  ; CHECK:  llvm.mul %[[ARG1]], %[[ARG2]] : i32
   %4 = mul i32 %arg1, %arg2
+  ; CHECK:  llvm.udiv %[[ARG3]], %[[ARG4]] : i64
   %5 = udiv i64 %arg3, %arg4
+  ; CHECK:  llvm.sdiv %[[ARG1]], %[[ARG2]] : i32
   %6 = sdiv i32 %arg1, %arg2
+  ; CHECK:  llvm.urem %[[ARG3]], %[[ARG4]] : i64
   %7 = urem i64 %arg3, %arg4
+  ; CHECK:  llvm.srem %[[ARG1]], %[[ARG2]] : i32
   %8 = srem i32 %arg1, %arg2
+  ; CHECK:  llvm.shl %[[ARG3]], %[[ARG4]] : i64
   %9 = shl i64 %arg3, %arg4
+  ; CHECK:  llvm.lshr %[[ARG1]], %[[ARG2]] : i32
   %10 = lshr i32 %arg1, %arg2
+  ; CHECK:  llvm.ashr %[[ARG3]], %[[ARG4]] : i64
   %11 = ashr i64 %arg3, %arg4
+  ; CHECK:  llvm.and %[[ARG1]], %[[ARG2]] : i32
   %12 = and i32 %arg1, %arg2
+  ; CHECK:  llvm.or %[[ARG3]], %[[ARG4]] : i64
   %13 = or i64 %arg3, %arg4
+  ; CHECK:  llvm.xor %[[ARG1]], %[[ARG2]] : i32
   %14 = xor i32 %arg1, %arg2
   ret void
 }
 
 ; // -----
 
+; CHECK-LABEL: @integer_compare
+; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[ARG2:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[ARG3:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[ARG4:[a-zA-Z0-9]+]]
+define i1 @integer_compare(i32 %arg1, i32 %arg2, <4 x i64> %arg3, <4 x i64> %arg4) {
+  ; CHECK:  llvm.icmp "eq" %[[ARG3]], %[[ARG4]] : vector<4xi64>
+  %1 = icmp eq <4 x i64> %arg3, %arg4
+  ; CHECK:  llvm.icmp "slt" %[[ARG1]], %[[ARG2]] : i32
+  %2 = icmp slt i32 %arg1, %arg2
+  ; CHECK:  llvm.icmp "sle" %[[ARG1]], %[[ARG2]] : i32
+  %3 = icmp sle i32 %arg1, %arg2
+  ; CHECK:  llvm.icmp "sgt" %[[ARG1]], %[[ARG2]] : i32
+  %4 = icmp sgt i32 %arg1, %arg2
+  ; CHECK:  llvm.icmp "sge" %[[ARG1]], %[[ARG2]] : i32
+  %5 = icmp sge i32 %arg1, %arg2
+  ; CHECK:  llvm.icmp "ult" %[[ARG1]], %[[ARG2]] : i32
+  %6 = icmp ult i32 %arg1, %arg2
+  ; CHECK:  llvm.icmp "ule" %[[ARG1]], %[[ARG2]] : i32
+  %7 = icmp ule i32 %arg1, %arg2
+  ; Verify scalar comparisons return a scalar boolean
+  ; CHECK:  llvm.icmp "ugt" %[[ARG1]], %[[ARG2]] : i32
+  %8 = icmp ugt i32 %arg1, %arg2
+  ret i1 %8
+}
+
+; // -----
+
 ; CHECK-LABEL: @fp_arith
 ; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
 ; CHECK-SAME:  %[[ARG2:[a-zA-Z0-9]+]]
@@ -50,26 +78,70 @@ define void @fp_arith(float %arg1, float %arg2, double %arg3, double %arg4) {
   ; CHECK:  %[[C1:[0-9]+]] = llvm.mlir.constant(3.030000e+01 : f64) : f64
   ; CHECK:  %[[C2:[0-9]+]] = llvm.mlir.constant(3.030000e+01 : f32) : f32
   ; CHECK:  llvm.fadd %[[C2]], %[[ARG1]] : f32
-  ; CHECK:  llvm.fadd %[[ARG1]], %[[ARG2]] : f32
-  ; CHECK:  llvm.fadd %[[C1]], %[[ARG3]] : f64
-  ; CHECK:  llvm.fsub %[[ARG1]], %[[ARG2]] : f32
-  ; CHECK:  llvm.fmul %[[ARG3]], %[[ARG4]] : f64
-  ; CHECK:  llvm.fdiv %[[ARG1]], %[[ARG2]] : f32
-  ; CHECK:  llvm.frem %[[ARG3]], %[[ARG4]] : f64
-  ; CHECK:  llvm.fneg %[[ARG1]] : f32
   %1 = fadd float 0x403E4CCCC0000000, %arg1
+  ; CHECK:  llvm.fadd %[[ARG1]], %[[ARG2]] : f32
   %2 = fadd float %arg1, %arg2
+  ; CHECK:  llvm.fadd %[[C1]], %[[ARG3]] : f64
   %3 = fadd double 3.030000e+01, %arg3
+  ; CHECK:  llvm.fsub %[[ARG1]], %[[ARG2]] : f32
   %4 = fsub float %arg1, %arg2
+  ; CHECK:  llvm.fmul %[[ARG3]], %[[ARG4]] : f64
   %5 = fmul double %arg3, %arg4
+  ; CHECK:  llvm.fdiv %[[ARG1]], %[[ARG2]] : f32
   %6 = fdiv float %arg1, %arg2
+  ; CHECK:  llvm.frem %[[ARG3]], %[[ARG4]] : f64
   %7 = frem double %arg3, %arg4
+  ; CHECK:  llvm.fneg %[[ARG1]] : f32
   %8 = fneg float %arg1
   ret void
 }
 
 ; // -----
 
+; CHECK-LABEL: @fp_compare
+; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[ARG2:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[ARG3:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[ARG4:[a-zA-Z0-9]+]]
+define <4 x i1> @fp_compare(float %arg1, float %arg2, <4 x double> %arg3, <4 x double> %arg4) {
+  ; CHECK:  llvm.fcmp "_false" %[[ARG1]], %[[ARG2]] : f32
+  %1 = fcmp false float %arg1, %arg2
+  ; CHECK:  llvm.fcmp "oeq" %[[ARG1]], %[[ARG2]] : f32
+  %2 = fcmp oeq float %arg1, %arg2
+  ; CHECK:  llvm.fcmp "ogt" %[[ARG1]], %[[ARG2]] : f32
+  %3 = fcmp ogt float %arg1, %arg2
+  ; CHECK:  llvm.fcmp "oge" %[[ARG1]], %[[ARG2]] : f32
+  %4 = fcmp oge float %arg1, %arg2
+  ; CHECK:  llvm.fcmp "olt" %[[ARG1]], %[[ARG2]] : f32
+  %5 = fcmp olt float %arg1, %arg2
+  ; CHECK:  llvm.fcmp "ole" %[[ARG1]], %[[ARG2]] : f32
+  %6 = fcmp ole float %arg1, %arg2
+  ; CHECK:  llvm.fcmp "one" %[[ARG1]], %[[ARG2]] : f32
+  %7 = fcmp one float %arg1, %arg2
+  ; CHECK:  llvm.fcmp "ord" %[[ARG1]], %[[ARG2]] : f32
+  %8 = fcmp ord float %arg1, %arg2
+  ; CHECK:  llvm.fcmp "ueq" %[[ARG1]], %[[ARG2]] : f32
+  %9 = fcmp ueq float %arg1, %arg2
+  ; CHECK:  llvm.fcmp "ugt" %[[ARG1]], %[[ARG2]] : f32
+  %10 = fcmp ugt float %arg1, %arg2
+  ; CHECK:  llvm.fcmp "uge" %[[ARG1]], %[[ARG2]] : f32
+  %11 = fcmp uge float %arg1, %arg2
+  ; CHECK:  llvm.fcmp "ult" %[[ARG1]], %[[ARG2]] : f32
+  %12 = fcmp ult float %arg1, %arg2
+  ; CHECK:  llvm.fcmp "ule" %[[ARG1]], %[[ARG2]] : f32
+  %13 = fcmp ule float %arg1, %arg2
+  ; CHECK:  llvm.fcmp "une" %[[ARG1]], %[[ARG2]] : f32
+  %14 = fcmp une float %arg1, %arg2
+  ; CHECK:  llvm.fcmp "uno" %[[ARG1]], %[[ARG2]] : f32
+  %15 = fcmp uno float %arg1, %arg2
+  ; Verify vector comparisons return a vector of booleans
+  ; CHECK:  llvm.fcmp "_true" %[[ARG3]], %[[ARG4]] : vector<4xf64>
+  %16 = fcmp true <4 x double> %arg3, %arg4
+  ret <4 x i1> %16
+}
+
+; // -----
+
 ; CHECK-LABEL: @fp_casts
 ; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
 ; CHECK-SAME:  %[[ARG2:[a-zA-Z0-9]+]]

diff  --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
index 5f44d9e4beb6f..6d06ea199ae8b 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
@@ -17,6 +17,7 @@
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
 
+#include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -203,6 +204,19 @@ static LogicalResult emitOneMLIRBuilder(const Record &record, raw_ostream &os,
   if (builderStrRef.empty())
     return success();
 
+  // Access the argument index array that maps argument indices to LLVM IR
+  // operand indices. If the operation defines no custom mapping, set the array
+  // to the identity permutation.
+  std::vector<int64_t> llvmArgIndices =
+      record.getValueAsListOfInts("llvmArgIndices");
+  if (llvmArgIndices.empty())
+    append_range(llvmArgIndices, seq<int64_t>(0, op.getNumArgs()));
+  if (llvmArgIndices.size() != static_cast<size_t>(op.getNumArgs())) {
+    return emitError(
+        "'llvmArgIndices' does not match the number of arguments for op " +
+        op.getOperationName());
+  }
+
   // Progressively create the builder string by replacing $-variables. Keep only
   // the not-yet-traversed part of the builder pattern to avoid re-traversing
   // the string multiple times.
@@ -215,9 +229,13 @@ static LogicalResult emitOneMLIRBuilder(const Record &record, raw_ostream &os,
     // Then, rewrite the name based on its kind.
     FailureOr<int> argIndex = getArgumentIndex(op, name);
     if (succeeded(argIndex)) {
-      // Process the argument value assuming the MLIR and LLVM operand orders
-      // match and there are no optional or variadic arguments.
-      bs << formatv("processValue(llvmOperands[{0}])", *argIndex);
+      // Access the LLVM IR operand that maps to the given argument index using
+      // the provided argument indices mapping.
+      // FIXME: support trailing variadic arguments.
+      int64_t operandIdx = llvmArgIndices[*argIndex];
+      assert(operandIdx >= 0 && "expected argument to have a mapping");
+      assert(!isVariadicOperandName(op, name) && "unexpected variadic operand");
+      bs << formatv("processValue(llvmOperands[{0}])", operandIdx);
     } else if (isResultName(op, name)) {
       assert(op.getNumResults() == 1 &&
              "expected operation to have one result");


        


More information about the Mlir-commits mailing list