[Mlir-commits] [mlir] cd8978e - [mlir][LLVMIR] Ask ICmpOp to return vector<Nxi1> when needed
Min-Yih Hsu
llvmlistbot at llvm.org
Wed Jun 15 14:34:00 PDT 2022
Author: Min-Yih Hsu
Date: 2022-06-15T14:33:48-07:00
New Revision: cd8978e19ed90ddd695a193525d50319e74ff507
URL: https://github.com/llvm/llvm-project/commit/cd8978e19ed90ddd695a193525d50319e74ff507
DIFF: https://github.com/llvm/llvm-project/commit/cd8978e19ed90ddd695a193525d50319e74ff507.diff
LOG: [mlir][LLVMIR] Ask ICmpOp to return vector<Nxi1> when needed
If either operand of ICmpOp is a vector, return a vector<Nxi1> result
rather than an i1 result.
Differential Revision: https://reviews.llvm.org/D127536
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Dialect/LLVMIR/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 6f4da42908fb0..f34812d45fb14 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -286,11 +286,8 @@ def LLVM_ICmpOp : LLVM_Op<"icmp", [NoSideEffect]> {
$res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
}];
let builders = [
- OpBuilder<(ins "ICmpPredicate":$predicate, "Value":$lhs, "Value":$rhs),
- [{
- build($_builder, $_state, IntegerType::get(lhs.getType().getContext(), 1),
- predicate, lhs, rhs);
- }]>];
+ OpBuilder<(ins "ICmpPredicate":$predicate, "Value":$lhs, "Value":$rhs)>
+ ];
let hasCustomAssemblyFormat = 1;
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index b972c1fda3dbb..6cfe09aaa82ef 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -90,8 +90,28 @@ static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
}
//===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::CmpOp.
-//===----------------------------------------------------------------------===//
+// Printing, parsing and builder for LLVM::CmpOp.
+//===----------------------------------------------------------------------===//
+
+/// Builds an `llvm.icmp` whose result type mirrors the operand shape: `i1`
+/// for scalar operands, or a vector of `i1` with the same element count when
+/// an operand is an LLVM-compatible vector. Scalable vector operands produce
+/// a scalable `i1` vector; calling `ElementCount::getFixedValue()` on them
+/// would assert.
+void ICmpOp::build(OpBuilder &builder, OperationState &result,
+                   ICmpPredicate predicate, Value lhs, Value rhs) {
+  Type boolType = IntegerType::get(lhs.getType().getContext(), 1);
+  // Both icmp operands must have the same type. Prefer lhs for the shape, but
+  // fall back to rhs so that a vector on either side yields a vector result.
+  Type shapeType = lhs.getType();
+  if (!LLVM::isCompatibleVectorType(shapeType))
+    shapeType = rhs.getType();
+  if (!LLVM::isCompatibleVectorType(shapeType)) {
+    build(builder, result, boolType, predicate, lhs, rhs);
+    return;
+  }
+  // Keep the element count and scalability of the operand vector; for fixed
+  // vectors this is the builtin vector<Nxi1>.
+  llvm::ElementCount numElements = LLVM::getVectorNumElements(shapeType);
+  build(builder, result,
+        LLVM::getVectorType(boolType, numElements.getKnownMinValue(),
+                            numElements.isScalable()),
+        predicate, lhs, rhs);
+}
void ICmpOp::print(OpAsmPrinter &p) {
p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0)
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index f18972caaf0a0..50a50af8eb7cf 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -14,9 +14,12 @@ func.func @ops(%arg0: i32, %arg1: f32,
// CHECK: {{.*}} = llvm.sdiv %[[I32]], %[[I32]] : i32
// CHECK: {{.*}} = llvm.urem %[[I32]], %[[I32]] : i32
// CHECK: {{.*}} = llvm.srem %[[I32]], %[[I32]] : i32
-// CHECK: {{.*}} = llvm.icmp "ne" %[[I32]], %[[I32]] : i32
-// CHECK: {{.*}} = llvm.icmp "ne" %[[I8PTR1]], %[[I8PTR1]] : !llvm.ptr<i8>
-// CHECK: {{.*}} = llvm.icmp "ne" %[[VI8PTR1]], %[[VI8PTR1]] : !llvm.vec<2 x ptr<i8>>
+// CHECK: %[[SCALAR_PRED0:.+]] = llvm.icmp "ne" %[[I32]], %[[I32]] : i32
+// CHECK: {{.*}} = llvm.add %[[SCALAR_PRED0]], %[[SCALAR_PRED0]] : i1
+// CHECK: %[[SCALAR_PRED1:.+]] = llvm.icmp "ne" %[[I8PTR1]], %[[I8PTR1]] : !llvm.ptr<i8>
+// CHECK: {{.*}} = llvm.add %[[SCALAR_PRED1]], %[[SCALAR_PRED1]] : i1
+// CHECK: %[[VEC_PRED:.+]] = llvm.icmp "ne" %[[VI8PTR1]], %[[VI8PTR1]] : !llvm.vec<2 x ptr<i8>>
+// CHECK: {{.*}} = llvm.add %[[VEC_PRED]], %[[VEC_PRED]] : vector<2xi1>
%0 = llvm.add %arg0, %arg0 : i32
%1 = llvm.sub %arg0, %arg0 : i32
%2 = llvm.mul %arg0, %arg0 : i32
@@ -25,8 +28,11 @@ func.func @ops(%arg0: i32, %arg1: f32,
%5 = llvm.urem %arg0, %arg0 : i32
%6 = llvm.srem %arg0, %arg0 : i32
%7 = llvm.icmp "ne" %arg0, %arg0 : i32
+ %typecheck_7 = llvm.add %7, %7 : i1
%ptrcmp = llvm.icmp "ne" %arg2, %arg2 : !llvm.ptr<i8>
+ %typecheck_ptrcmp = llvm.add %ptrcmp, %ptrcmp : i1
%vptrcmp = llvm.icmp "ne" %arg5, %arg5 : !llvm.vec<2 x ptr<i8>>
+ %typecheck_vptrcmp = llvm.add %vptrcmp, %vptrcmp : vector<2 x i1>
// Floating point binary operations.
//
More information about the Mlir-commits
mailing list