[Mlir-commits] [mlir] cd8978e - [mlir][LLVMIR] Ask ICmpOp to return vector<Nxi1> when needed

Min-Yih Hsu llvmlistbot at llvm.org
Wed Jun 15 14:34:00 PDT 2022


Author: Min-Yih Hsu
Date: 2022-06-15T14:33:48-07:00
New Revision: cd8978e19ed90ddd695a193525d50319e74ff507

URL: https://github.com/llvm/llvm-project/commit/cd8978e19ed90ddd695a193525d50319e74ff507
DIFF: https://github.com/llvm/llvm-project/commit/cd8978e19ed90ddd695a193525d50319e74ff507.diff

LOG: [mlir][LLVMIR] Ask ICmpOp to return vector<Nxi1> when needed

If any of the operands of ICmpOp is a vector, return a vector<Nxi1> result
rather than an i1 result.

Differential Revision: https://reviews.llvm.org/D127536

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/test/Dialect/LLVMIR/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 6f4da42908fb0..f34812d45fb14 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -286,11 +286,8 @@ def LLVM_ICmpOp : LLVM_Op<"icmp", [NoSideEffect]> {
     $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
   }];
   let builders = [
-    OpBuilder<(ins "ICmpPredicate":$predicate, "Value":$lhs, "Value":$rhs),
-    [{
-      build($_builder, $_state, IntegerType::get(lhs.getType().getContext(), 1),
-            predicate, lhs, rhs);
-    }]>];
+    OpBuilder<(ins "ICmpPredicate":$predicate, "Value":$lhs, "Value":$rhs)>
+  ];
   let hasCustomAssemblyFormat = 1;
 }
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index b972c1fda3dbb..6cfe09aaa82ef 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -90,8 +90,28 @@ static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
 }
 
 //===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::CmpOp.
-//===----------------------------------------------------------------------===//
+// Printing, parsing and builder for LLVM::CmpOp.
+//===----------------------------------------------------------------------===//
+
+void ICmpOp::build(OpBuilder &builder, OperationState &result,
+                   ICmpPredicate predicate, Value lhs, Value rhs) { // Builder that derives the result type from the operand types.
+  auto boolType = IntegerType::get(lhs.getType().getContext(), 1); // 1-bit integer type (i1), created in lhs's context.
+  if (LLVM::isCompatibleVectorType(lhs.getType()) ||
+      LLVM::isCompatibleVectorType(rhs.getType())) { // At least one operand is a vector: produce a vector-of-i1 result.
+    int64_t numLHSElements = 1, numRHSElements = 1; // An operand that is not a vector keeps the default count of 1.
+    if (LLVM::isCompatibleVectorType(lhs.getType()))
+      numLHSElements =
+          LLVM::getVectorNumElements(lhs.getType()).getFixedValue(); // NOTE(review): presumably only valid for fixed-length vectors; confirm behavior for scalable vectors.
+    if (LLVM::isCompatibleVectorType(rhs.getType()))
+      numRHSElements =
+          LLVM::getVectorNumElements(rhs.getType()).getFixedValue(); // NOTE(review): same fixed-length assumption as for lhs.
+    build(builder, result,
+          VectorType::get({std::max(numLHSElements, numRHSElements)}, boolType), // 1-D vector of i1 sized by the larger element count.
+          predicate, lhs, rhs);
+  } else {
+    build(builder, result, boolType, predicate, lhs, rhs); // Neither operand is a vector: scalar i1 result.
+  }
+}
 
 void ICmpOp::print(OpAsmPrinter &p) {
   p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0)

diff  --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index f18972caaf0a0..50a50af8eb7cf 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -14,9 +14,12 @@ func.func @ops(%arg0: i32, %arg1: f32,
 // CHECK: {{.*}} = llvm.sdiv %[[I32]], %[[I32]] : i32
 // CHECK: {{.*}} = llvm.urem %[[I32]], %[[I32]] : i32
 // CHECK: {{.*}} = llvm.srem %[[I32]], %[[I32]] : i32
-// CHECK: {{.*}} = llvm.icmp "ne" %[[I32]], %[[I32]] : i32
-// CHECK: {{.*}} = llvm.icmp "ne" %[[I8PTR1]], %[[I8PTR1]] : !llvm.ptr<i8>
-// CHECK: {{.*}} = llvm.icmp "ne" %[[VI8PTR1]], %[[VI8PTR1]] : !llvm.vec<2 x ptr<i8>>
+// CHECK: %[[SCALAR_PRED0:.+]] = llvm.icmp "ne" %[[I32]], %[[I32]] : i32
+// CHECK: {{.*}} = llvm.add %[[SCALAR_PRED0]], %[[SCALAR_PRED0]] : i1
+// CHECK: %[[SCALAR_PRED1:.+]] = llvm.icmp "ne" %[[I8PTR1]], %[[I8PTR1]] : !llvm.ptr<i8>
+// CHECK: {{.*}} = llvm.add %[[SCALAR_PRED1]], %[[SCALAR_PRED1]] : i1
+// CHECK: %[[VEC_PRED:.+]] = llvm.icmp "ne" %[[VI8PTR1]], %[[VI8PTR1]] : !llvm.vec<2 x ptr<i8>>
+// CHECK: {{.*}} = llvm.add %[[VEC_PRED]], %[[VEC_PRED]] : vector<2xi1>
   %0 = llvm.add %arg0, %arg0 : i32
   %1 = llvm.sub %arg0, %arg0 : i32
   %2 = llvm.mul %arg0, %arg0 : i32
@@ -25,8 +28,11 @@ func.func @ops(%arg0: i32, %arg1: f32,
   %5 = llvm.urem %arg0, %arg0 : i32
   %6 = llvm.srem %arg0, %arg0 : i32
   %7 = llvm.icmp "ne" %arg0, %arg0 : i32
+  %typecheck_7 = llvm.add %7, %7 : i1
   %ptrcmp = llvm.icmp "ne" %arg2, %arg2 : !llvm.ptr<i8>
+  %typecheck_ptrcmp = llvm.add %ptrcmp, %ptrcmp : i1
   %vptrcmp = llvm.icmp "ne" %arg5, %arg5 : !llvm.vec<2 x ptr<i8>>
+  %typecheck_vptrcmp = llvm.add %vptrcmp, %vptrcmp : vector<2 x i1>
 
 // Floating point binary operations.
 //


        


More information about the Mlir-commits mailing list