[Mlir-commits] [mlir] [MLIR][LLVM] Improve bit- and addrspacecast folders (PR #87745)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 4 23:27:59 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Christian Ulmann (Dinistro)

<details>
<summary>Changes</summary>

This commit extends the folders of the chainable casts (bitcast and addrspacecast) so that a chain of casts of the same kind is folded into a single cast, e.g. `cast(cast(x : T0, T1), T2)` becomes `cast(x : T0, T2)`.

---
Full diff: https://github.com/llvm/llvm-project/pull/87745.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+18-15) 
- (modified) mlir/test/Dialect/LLVMIR/canonicalize.mlir (+24) 
- (modified) mlir/test/Dialect/LLVMIR/mem2reg.mlir (+2-3) 


``````````diff
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index e5c19a916392e1..f90240a67dcc5f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2761,17 +2761,27 @@ LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
 // Folder and verifier for LLVM::BitcastOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
-  // bitcast(x : T0, T0) -> x
-  if (getArg().getType() == getType())
-    return getArg();
-  // bitcast(bitcast(x : T0, T1), T0) -> x
-  if (auto prev = getArg().getDefiningOp<BitcastOp>())
-    if (prev.getArg().getType() == getType())
+/// Folds a cast op that can be chained.
+template <typename T>
+static Value foldChainableCast(T castOp, typename T::FoldAdaptor adaptor) {
+  // cast(x : T0, T0) -> x
+  if (castOp.getArg().getType() == castOp.getType())
+    return castOp.getArg();
+  if (auto prev = castOp.getArg().template getDefiningOp<T>()) {
+    // cast(cast(x : T0, T1), T0) -> x
+    if (prev.getArg().getType() == castOp.getType())
       return prev.getArg();
+    // cast(cast(x : T0, T1), T2) -> cast(x: T0, T2)
+    castOp.getArgMutable().set(prev.getArg());
+    return Value{castOp};
+  }
   return {};
 }
 
+OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
+  return foldChainableCast(*this, adaptor);
+}
+
 LogicalResult LLVM::BitcastOp::verify() {
   auto resultType = llvm::dyn_cast<LLVMPointerType>(
       extractVectorElementType(getResult().getType()));
@@ -2811,14 +2821,7 @@ LogicalResult LLVM::BitcastOp::verify() {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
-  // addrcast(x : T0, T0) -> x
-  if (getArg().getType() == getType())
-    return getArg();
-  // addrcast(addrcast(x : T0, T1), T0) -> x
-  if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>())
-    if (prev.getArg().getType() == getType())
-      return prev.getArg();
-  return {};
+  return foldChainableCast(*this, adaptor);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 5e26fa37b681d7..fb653b1768ea39 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -101,6 +101,18 @@ llvm.func @fold_bitcast2(%x : i32) -> i32 {
 
 // -----
 
+// CHECK-LABEL: fold_bitcast_chain
+// CHECK-SAME: %[[a0:arg[0-9]+]]
+llvm.func @fold_bitcast_chain(%x : i32) -> vector<2xi16> {
+  %c = llvm.bitcast %x : i32 to f32
+  %d = llvm.bitcast %c : f32 to vector<2xi16>
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[a0]] : i32 to vector<2xi16>
+  // CHECK: llvm.return %[[BITCAST]]
+  llvm.return %d : vector<2xi16>
+}
+
+// -----
+
 // CHECK-LABEL: fold_addrcast
 // CHECK-SAME: %[[a0:arg[0-9]+]]
 // CHECK-NEXT: llvm.return %[[a0]]
@@ -120,6 +132,18 @@ llvm.func @fold_addrcast2(%x : !llvm.ptr) -> !llvm.ptr {
 
 // -----
 
+// CHECK-LABEL: fold_addrcast_chain
+// CHECK-SAME: %[[a0:arg[0-9]+]]
+llvm.func @fold_addrcast_chain(%x : !llvm.ptr) -> !llvm.ptr<2> {
+  %c = llvm.addrspacecast %x : !llvm.ptr to !llvm.ptr<1>
+  %d = llvm.addrspacecast %c : !llvm.ptr<1> to !llvm.ptr<2>
+  // CHECK: %[[ADDRCAST:.*]] = llvm.addrspacecast %[[a0]] : !llvm.ptr to !llvm.ptr<2>
+  // CHECK: llvm.return %[[ADDRCAST]]
+  llvm.return %d : !llvm.ptr<2>
+}
+
+// -----
+
 // CHECK-LABEL: fold_gep
 // CHECK-SAME: %[[a0:arg[0-9]+]]
 // CHECK-NEXT: llvm.return %[[a0]]
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 61a3d933ee1510..fa5d842302d0f4 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -793,9 +793,8 @@ llvm.func @store_int_to_vector(%arg: i32) -> vector<4xi8> {
   %1 = llvm.alloca %0 x vector<2xi16> {alignment = 4 : i64} : (i32) -> !llvm.ptr
   llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> vector<4xi8>
-  // CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<2xi16>
-  // CHECK: %[[BITCAST1:.*]] = llvm.bitcast %[[BITCAST0]] : vector<2xi16> to vector<4xi8>
-  // CHECK: llvm.return %[[BITCAST1]]
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<4xi8>
+  // CHECK: llvm.return %[[BITCAST]]
   llvm.return %2 : vector<4xi8>
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/87745


More information about the Mlir-commits mailing list