[Mlir-commits] [mlir] [MLIR][LLVM] Improve bit- and addrspacecast folders (PR #87745)

Christian Ulmann llvmlistbot at llvm.org
Thu Apr 4 23:49:11 PDT 2024


https://github.com/Dinistro updated https://github.com/llvm/llvm-project/pull/87745

>From ba4ee3cbc516aefb9e30d558bd69493fed3c5e56 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 5 Apr 2024 06:19:55 +0000
Subject: [PATCH 1/2] [MLIR][LLVM] Improve bit- and addrspacecast folders

This commit extends the folders of chainable casts (bitcast and
addrspacecast) so that a chain of casts of the same kind is folded into
a single cast.
---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 33 ++++++++++++----------
 mlir/test/Dialect/LLVMIR/canonicalize.mlir | 24 ++++++++++++++++
 mlir/test/Dialect/LLVMIR/mem2reg.mlir      |  5 ++--
 3 files changed, 44 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index e5c19a916392e1..f90240a67dcc5f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2761,17 +2761,27 @@ LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
 // Folder and verifier for LLVM::BitcastOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
-  // bitcast(x : T0, T0) -> x
-  if (getArg().getType() == getType())
-    return getArg();
-  // bitcast(bitcast(x : T0, T1), T0) -> x
-  if (auto prev = getArg().getDefiningOp<BitcastOp>())
-    if (prev.getArg().getType() == getType())
+/// Folds a cast op that can be chained.
+template <typename T>
+static Value foldChainableCast(T castOp, typename T::FoldAdaptor adaptor) {
+  // cast(x : T0, T0) -> x
+  if (castOp.getArg().getType() == castOp.getType())
+    return castOp.getArg();
+  if (auto prev = castOp.getArg().template getDefiningOp<T>()) {
+    // cast(cast(x : T0, T1), T0) -> x
+    if (prev.getArg().getType() == castOp.getType())
       return prev.getArg();
+    // cast(cast(x : T0, T1), T2) -> cast(x : T0, T2)
+    castOp.getArgMutable().set(prev.getArg());
+    return Value{castOp};
+  }
   return {};
 }
 
+OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
+  return foldChainableCast(*this, adaptor);
+}
+
 LogicalResult LLVM::BitcastOp::verify() {
   auto resultType = llvm::dyn_cast<LLVMPointerType>(
       extractVectorElementType(getResult().getType()));
@@ -2811,14 +2821,7 @@ LogicalResult LLVM::BitcastOp::verify() {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
-  // addrcast(x : T0, T0) -> x
-  if (getArg().getType() == getType())
-    return getArg();
-  // addrcast(addrcast(x : T0, T1), T0) -> x
-  if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>())
-    if (prev.getArg().getType() == getType())
-      return prev.getArg();
-  return {};
+  return foldChainableCast(*this, adaptor);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 5e26fa37b681d7..fb653b1768ea39 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -101,6 +101,18 @@ llvm.func @fold_bitcast2(%x : i32) -> i32 {
 
 // -----
 
+// CHECK-LABEL: fold_bitcast_chain
+// CHECK-SAME: %[[a0:arg[0-9]+]]
+llvm.func @fold_bitcast_chain(%x : i32) -> vector<2xi16> {
+  %c = llvm.bitcast %x : i32 to f32
+  %d = llvm.bitcast %c : f32 to vector<2xi16>
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[a0]] : i32 to vector<2xi16>
+  // CHECK: llvm.return %[[BITCAST]]
+  llvm.return %d : vector<2xi16>
+}
+
+// -----
+
 // CHECK-LABEL: fold_addrcast
 // CHECK-SAME: %[[a0:arg[0-9]+]]
 // CHECK-NEXT: llvm.return %[[a0]]
@@ -120,6 +132,18 @@ llvm.func @fold_addrcast2(%x : !llvm.ptr) -> !llvm.ptr {
 
 // -----
 
+// CHECK-LABEL: fold_addrcast_chain
+// CHECK-SAME: %[[a0:arg[0-9]+]]
+llvm.func @fold_addrcast_chain(%x : !llvm.ptr) -> !llvm.ptr<2> {
+  %c = llvm.addrspacecast %x : !llvm.ptr to !llvm.ptr<1>
+  %d = llvm.addrspacecast %c : !llvm.ptr<1> to !llvm.ptr<2>
+  // CHECK: %[[ADDRCAST:.*]] = llvm.addrspacecast %[[a0]] : !llvm.ptr to !llvm.ptr<2>
+  // CHECK: llvm.return %[[ADDRCAST]]
+  llvm.return %d : !llvm.ptr<2>
+}
+
+// -----
+
 // CHECK-LABEL: fold_gep
 // CHECK-SAME: %[[a0:arg[0-9]+]]
 // CHECK-NEXT: llvm.return %[[a0]]
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 61a3d933ee1510..fa5d842302d0f4 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -793,9 +793,8 @@ llvm.func @store_int_to_vector(%arg: i32) -> vector<4xi8> {
   %1 = llvm.alloca %0 x vector<2xi16> {alignment = 4 : i64} : (i32) -> !llvm.ptr
   llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> vector<4xi8>
-  // CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<2xi16>
-  // CHECK: %[[BITCAST1:.*]] = llvm.bitcast %[[BITCAST0]] : vector<2xi16> to vector<4xi8>
-  // CHECK: llvm.return %[[BITCAST1]]
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<4xi8>
+  // CHECK: llvm.return %[[BITCAST]]
   llvm.return %2 : vector<4xi8>
 }
 

>From b2da121e1d841ccafddac9f6184b68c12670883c Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 5 Apr 2024 06:48:58 +0000
Subject: [PATCH 2/2] clean up the canonicalization test file

---
 mlir/test/Dialect/LLVMIR/canonicalize.mlir | 52 ++++++++++++++--------
 1 file changed, 34 insertions(+), 18 deletions(-)

diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index fb653b1768ea39..6b265bbbdbfb2d 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -8,6 +8,8 @@ llvm.func @fold_icmp_eq(%arg0 : i32) -> i1 {
   llvm.return %0 : i1
 }
 
+// -----
+
 // CHECK-LABEL: @fold_icmp_ne
 llvm.func @fold_icmp_ne(%arg0 : vector<2xi32>) -> vector<2xi1> {
   // CHECK: %[[C0:.*]] = llvm.mlir.constant(dense<false> : vector<2xi1>) : vector<2xi1>
@@ -16,6 +18,8 @@ llvm.func @fold_icmp_ne(%arg0 : vector<2xi32>) -> vector<2xi1> {
   llvm.return %0 : vector<2xi1>
 }
 
+// -----
+
 // CHECK-LABEL: @fold_icmp_alloca
 llvm.func @fold_icmp_alloca() -> i1 {
   // CHECK: %[[C0:.*]] = llvm.mlir.constant(true) : i1
@@ -83,16 +87,18 @@ llvm.func @fold_unrelated_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
 // -----
 
 // CHECK-LABEL: fold_bitcast
-// CHECK-SAME: %[[a0:arg[0-9]+]]
-// CHECK-NEXT: llvm.return %[[a0]]
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
+// CHECK-NEXT: llvm.return %[[ARG]]
 llvm.func @fold_bitcast(%x : !llvm.ptr) -> !llvm.ptr {
   %c = llvm.bitcast %x : !llvm.ptr to !llvm.ptr
   llvm.return %c : !llvm.ptr
 }
 
+// -----
+
 // CHECK-LABEL: fold_bitcast2
-// CHECK-SAME: %[[a0:arg[0-9]+]]
-// CHECK-NEXT: llvm.return %[[a0]]
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
+// CHECK-NEXT: llvm.return %[[ARG]]
 llvm.func @fold_bitcast2(%x : i32) -> i32 {
   %c = llvm.bitcast %x : i32 to f32
   %d = llvm.bitcast %c : f32 to i32
@@ -102,11 +108,11 @@ llvm.func @fold_bitcast2(%x : i32) -> i32 {
 // -----
 
 // CHECK-LABEL: fold_bitcast_chain
-// CHECK-SAME: %[[a0:arg[0-9]+]]
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
 llvm.func @fold_bitcast_chain(%x : i32) -> vector<2xi16> {
   %c = llvm.bitcast %x : i32 to f32
   %d = llvm.bitcast %c : f32 to vector<2xi16>
-  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[a0]] : i32 to vector<2xi16>
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<2xi16>
   // CHECK: llvm.return %[[BITCAST]]
   llvm.return %d : vector<2xi16>
 }
@@ -114,16 +120,18 @@ llvm.func @fold_bitcast_chain(%x : i32) -> vector<2xi16> {
 // -----
 
 // CHECK-LABEL: fold_addrcast
-// CHECK-SAME: %[[a0:arg[0-9]+]]
-// CHECK-NEXT: llvm.return %[[a0]]
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
+// CHECK-NEXT: llvm.return %[[ARG]]
 llvm.func @fold_addrcast(%x : !llvm.ptr) -> !llvm.ptr {
   %c = llvm.addrspacecast %x : !llvm.ptr to !llvm.ptr
   llvm.return %c : !llvm.ptr
 }
 
+// -----
+
 // CHECK-LABEL: fold_addrcast2
-// CHECK-SAME: %[[a0:arg[0-9]+]]
-// CHECK-NEXT: llvm.return %[[a0]]
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
+// CHECK-NEXT: llvm.return %[[ARG]]
 llvm.func @fold_addrcast2(%x : !llvm.ptr) -> !llvm.ptr {
   %c = llvm.addrspacecast %x : !llvm.ptr to !llvm.ptr<5>
   %d = llvm.addrspacecast %c : !llvm.ptr<5> to !llvm.ptr
@@ -133,11 +141,11 @@ llvm.func @fold_addrcast2(%x : !llvm.ptr) -> !llvm.ptr {
 // -----
 
 // CHECK-LABEL: fold_addrcast_chain
-// CHECK-SAME: %[[a0:arg[0-9]+]]
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
 llvm.func @fold_addrcast_chain(%x : !llvm.ptr) -> !llvm.ptr<2> {
   %c = llvm.addrspacecast %x : !llvm.ptr to !llvm.ptr<1>
   %d = llvm.addrspacecast %c : !llvm.ptr<1> to !llvm.ptr<2>
-  // CHECK: %[[ADDRCAST:.*]] = llvm.addrspacecast %[[a0]] : !llvm.ptr to !llvm.ptr<2>
+  // CHECK: %[[ADDRCAST:.*]] = llvm.addrspacecast %[[ARG]] : !llvm.ptr to !llvm.ptr<2>
   // CHECK: llvm.return %[[ADDRCAST]]
   llvm.return %d : !llvm.ptr<2>
 }
@@ -145,17 +153,19 @@ llvm.func @fold_addrcast_chain(%x : !llvm.ptr) -> !llvm.ptr<2> {
 // -----
 
 // CHECK-LABEL: fold_gep
-// CHECK-SAME: %[[a0:arg[0-9]+]]
-// CHECK-NEXT: llvm.return %[[a0]]
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
+// CHECK-NEXT: llvm.return %[[ARG]]
 llvm.func @fold_gep(%x : !llvm.ptr) -> !llvm.ptr {
   %c0 = arith.constant 0 : i32
   %c = llvm.getelementptr %x[%c0] : (!llvm.ptr, i32) -> !llvm.ptr, i8
   llvm.return %c : !llvm.ptr
 }
 
+// -----
+
 // CHECK-LABEL: fold_gep_neg
-// CHECK-SAME: %[[a0:arg[0-9]+]]
-// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr inbounds %[[a0]][0, 1]
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
+// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr inbounds %[[ARG]][0, 1]
 // CHECK-NEXT: llvm.return %[[RES]]
 llvm.func @fold_gep_neg(%x : !llvm.ptr) -> !llvm.ptr {
   %c0 = arith.constant 0 : i32
@@ -163,9 +173,11 @@ llvm.func @fold_gep_neg(%x : !llvm.ptr) -> !llvm.ptr {
   llvm.return %0 : !llvm.ptr
 }
 
+// -----
+
 // CHECK-LABEL: fold_gep_canon
-// CHECK-SAME: %[[a0:arg[0-9]+]]
-// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr %[[a0]][2]
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
+// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr %[[ARG]][2]
 // CHECK-NEXT: llvm.return %[[RES]]
 llvm.func @fold_gep_canon(%x : !llvm.ptr) -> !llvm.ptr {
   %c2 = arith.constant 2 : i32
@@ -199,6 +211,8 @@ llvm.func @load_dce(%x : !llvm.ptr) {
   llvm.return
 }
 
+// -----
+
 llvm.mlir.global external @fp() : !llvm.ptr
 
 // CHECK-LABEL: addr_dce
@@ -208,6 +222,8 @@ llvm.func @addr_dce(%x : !llvm.ptr) {
   llvm.return
 }
 
+// -----
+
 // CHECK-LABEL: alloca_dce
 // CHECK-NEXT: llvm.return
 llvm.func @alloca_dce() {



More information about the Mlir-commits mailing list