[Mlir-commits] [mlir] [mlir][llvm] Add icmp folder (PR #65343)

Tobias Gysi llvmlistbot at llvm.org
Tue Sep 5 09:03:04 PDT 2023


https://github.com/gysit created https://github.com/llvm/llvm-project/pull/65343:

This revision adds a simple icmp folder that performs the following folds on the LLVM dialect icmp op:
 - icmp(eq/ne, x, x) -> true/false
 - icmp(eq/ne, alloca, null) -> false/true
 - icmp(eq/ne, null, alloca) -> icmp(eq/ne, alloca, null)

>From b0997c85d3d3ec687ca2b1e97034a49fa5c4aab8 Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Tue, 5 Sep 2023 15:51:21 +0000
Subject: [PATCH] [mlir][llvm] Add icmp folder

This revision adds a simple icmp folder that performs the
following folds on the LLVM dialect icmp op:
 - icmp(eq/ne, x, x) -> true/false
 - icmp(eq/ne, alloca, null) -> false/true
 - icmp(eq/ne, null, alloca) -> icmp(eq/ne, alloca, null)
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td |  1 +
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp  | 38 ++++++++++++++++++++-
 mlir/test/Dialect/LLVMIR/canonicalize.mlir  | 29 ++++++++++++++++
 3 files changed, 67 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 2b4c8b609cfdd4f..2e09dc4a18786ad 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -139,6 +139,7 @@ def LLVM_ICmpOp : LLVM_ArithmeticCmpOp<"icmp", [Pure]> {
   // Set the $predicate index to -1 to indicate there is no matching operand
   // and decrement the following indices.
   list<int> llvmArgIndices = [-1, 0, 1];
+  let hasFolder = 1;
 }
 
 // Other floating-point operations.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index fd0d2b3fb3c1a08..9836c2b5e40a935 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -98,7 +98,7 @@ static Type getI1SameShape(Type type) {
 }
 
 //===----------------------------------------------------------------------===//
-// Printing, parsing and builder for LLVM::CmpOp.
+// Printing, parsing, folding and builder for LLVM::ICmpOp and LLVM::FCmpOp.
 //===----------------------------------------------------------------------===//
 
 void ICmpOp::print(OpAsmPrinter &p) {
@@ -175,6 +175,42 @@ ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) {
   return parseCmpOp<FCmpPredicate>(parser, result);
 }
 
+/// Builds the boolean constant `value` for `type`: a splat DenseElementsAttr
+/// when `type` is a ShapedType, and a plain BoolAttr otherwise.
+static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
+  BoolAttr scalar = BoolAttr::get(ctx, value);
+  if (auto shaped = llvm::dyn_cast_or_null<ShapedType>(type))
+    return DenseElementsAttr::get(shaped, scalar);
+  return scalar;
+}
+
+OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
+  if (getPredicate() != ICmpPredicate::eq &&
+      getPredicate() != ICmpPredicate::ne)
+    return {};
+
+  // icmp(eq/ne, x, x) -> true/false
+  if (getLhs() == getRhs())
+    return getBoolAttribute(getType(), getContext(),
+                            getPredicate() == ICmpPredicate::eq);
+  // Alloca folds need a non-null alloca; LLVM only guarantees that in AS 0.
+  auto ptrType = llvm::dyn_cast<LLVMPointerType>(getLhs().getType());
+  if (!ptrType || ptrType.getAddressSpace() != 0)
+    return {};
+  // icmp(eq/ne, alloca, null) -> false/true
+  if (getLhs().getDefiningOp<AllocaOp>() && getRhs().getDefiningOp<NullOp>())
+    return getBoolAttribute(getType(), getContext(),
+                            getPredicate() == ICmpPredicate::ne);
+  // icmp(eq/ne, null, alloca) -> icmp(eq/ne, alloca, null), swapped in place.
+  if (getLhs().getDefiningOp<NullOp>() && getRhs().getDefiningOp<AllocaOp>()) {
+    Value nullPtr = getLhs();
+    getLhsMutable().assign(getRhs());
+    getRhsMutable().assign(nullPtr);
+    return getResult();
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // Printing, parsing and verification for LLVM::AllocaOp.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 6b2cac14f29859f..c8f45e4d0e17138 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -1,5 +1,34 @@
 // RUN: mlir-opt --pass-pipeline='builtin.module(llvm.func(canonicalize{test-convergence}))' %s -split-input-file | FileCheck %s
 
+// CHECK-LABEL: @fold_icmp_eq
+llvm.func @fold_icmp_eq(%x : i32) -> i1 {
+  // CHECK: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
+  %eq = llvm.icmp "eq" %x, %x : i32
+  // CHECK: llvm.return %[[TRUE]]
+  llvm.return %eq : i1
+}
+
+// CHECK-LABEL: @fold_icmp_ne
+llvm.func @fold_icmp_ne(%x : i32) -> i1 {
+  // CHECK: %[[FALSE:.*]] = llvm.mlir.constant(false) : i1
+  %ne = llvm.icmp "ne" %x, %x : i32
+  // CHECK: llvm.return %[[FALSE]]
+  llvm.return %ne : i1
+}
+
+// CHECK-LABEL: @fold_icmp_alloca
+llvm.func @fold_icmp_alloca() -> i1 {
+  // CHECK: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
+  %null = llvm.mlir.null : !llvm.ptr
+  %one = arith.constant 1 : i64
+  %ptr = llvm.alloca %one x i32 : (i64) -> !llvm.ptr
+  %ne = llvm.icmp "ne" %null, %ptr : !llvm.ptr
+  // CHECK: llvm.return %[[TRUE]]
+  llvm.return %ne : i1
+}
+
+// -----
+
 // CHECK-LABEL: fold_extractvalue
 llvm.func @fold_extractvalue() -> i32 {
   //  CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32



More information about the Mlir-commits mailing list