[Mlir-commits] [mlir] [MLIR][LLVM] Improve bit- and addrspacecast folders (PR #87745)
Christian Ulmann
llvmlistbot at llvm.org
Thu Apr 4 23:27:31 PDT 2024
https://github.com/Dinistro created https://github.com/llvm/llvm-project/pull/87745
This commit extends the folders of chainable casts (bitcast and addrspacecast) so that a chain of casts of the same kind folds into a single cast.
>From ba4ee3cbc516aefb9e30d558bd69493fed3c5e56 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 5 Apr 2024 06:19:55 +0000
Subject: [PATCH] [MLIR][LLVM] Improve bit- and addrspacecast folders
This commit extends the folders of chainable casts (bitcast and
addrspacecast) so that a chain of casts of the same kind folds into
a single cast.
---
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 33 ++++++++++++----------
mlir/test/Dialect/LLVMIR/canonicalize.mlir | 24 ++++++++++++++++
mlir/test/Dialect/LLVMIR/mem2reg.mlir | 5 ++--
3 files changed, 44 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index e5c19a916392e1..f90240a67dcc5f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2761,17 +2761,34 @@ LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
 // Folder and verifier for LLVM::BitcastOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
-  // bitcast(x : T0, T0) -> x
-  if (getArg().getType() == getType())
-    return getArg();
-  // bitcast(bitcast(x : T0, T1), T0) -> x
-  if (auto prev = getArg().getDefiningOp<BitcastOp>())
-    if (prev.getArg().getType() == getType())
+/// Folds a cast op `T` that can be chained with casts of the same kind:
+///   cast(x : T0, T0)           -> x
+///   cast(cast(x : T0, T1), T0) -> x
+///   cast(cast(x : T0, T1), T2) -> cast(x : T0, T2)
+/// The last pattern rewires the operand of `castOp` in place and returns the
+/// op's own result, which is how a folder reports an in-place update.
+template <typename T>
+static Value foldChainableCast(T castOp, typename T::FoldAdaptor /*adaptor*/) {
+  // cast(x : T0, T0) -> x
+  if (castOp.getArg().getType() == castOp.getType())
+    return castOp.getArg();
+  if (auto prev = castOp.getArg().template getDefiningOp<T>()) {
+    // cast(cast(x : T0, T1), T0) -> x
+    if (prev.getArg().getType() == castOp.getType())
       return prev.getArg();
+    // cast(cast(x : T0, T1), T2) -> cast(x : T0, T2)
+    // Only the operand is rewired; `prev` is not erased here even if it
+    // becomes unused.
+    castOp.getArgMutable().set(prev.getArg());
+    return castOp.getResult();
+  }
   return {};
 }
 
+OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
+  return foldChainableCast(*this, adaptor);
+}
+
 LogicalResult LLVM::BitcastOp::verify() {
   auto resultType = llvm::dyn_cast<LLVMPointerType>(
       extractVectorElementType(getResult().getType()));
@@ -2811,14 +2828,7 @@ LogicalResult LLVM::BitcastOp::verify() {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
-  // addrcast(x : T0, T0) -> x
-  if (getArg().getType() == getType())
-    return getArg();
-  // addrcast(addrcast(x : T0, T1), T0) -> x
-  if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>())
-    if (prev.getArg().getType() == getType())
-      return prev.getArg();
-  return {};
+  return foldChainableCast(*this, adaptor);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 5e26fa37b681d7..fb653b1768ea39 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -101,6 +101,18 @@ llvm.func @fold_bitcast2(%x : i32) -> i32 {
 
 // -----
 
+// CHECK-LABEL: fold_bitcast_chain
+// CHECK-SAME: %[[a0:arg[0-9]+]]
+llvm.func @fold_bitcast_chain(%x : i32) -> vector<2xi16> {
+  %c = llvm.bitcast %x : i32 to f32
+  %d = llvm.bitcast %c : f32 to vector<2xi16>
+  // CHECK-NEXT: %[[BITCAST:.*]] = llvm.bitcast %[[a0]] : i32 to vector<2xi16>
+  // CHECK-NEXT: llvm.return %[[BITCAST]]
+  llvm.return %d : vector<2xi16>
+}
+
+// -----
+
 // CHECK-LABEL: fold_addrcast
 // CHECK-SAME: %[[a0:arg[0-9]+]]
 // CHECK-NEXT: llvm.return %[[a0]]
@@ -120,6 +132,18 @@ llvm.func @fold_addrcast2(%x : !llvm.ptr) -> !llvm.ptr {
 
 // -----
 
+// CHECK-LABEL: fold_addrcast_chain
+// CHECK-SAME: %[[a0:arg[0-9]+]]
+llvm.func @fold_addrcast_chain(%x : !llvm.ptr) -> !llvm.ptr<2> {
+  %c = llvm.addrspacecast %x : !llvm.ptr to !llvm.ptr<1>
+  %d = llvm.addrspacecast %c : !llvm.ptr<1> to !llvm.ptr<2>
+  // CHECK-NEXT: %[[ADDRCAST:.*]] = llvm.addrspacecast %[[a0]] : !llvm.ptr to !llvm.ptr<2>
+  // CHECK-NEXT: llvm.return %[[ADDRCAST]]
+  llvm.return %d : !llvm.ptr<2>
+}
+
+// -----
+
 // CHECK-LABEL: fold_gep
 // CHECK-SAME: %[[a0:arg[0-9]+]]
 // CHECK-NEXT: llvm.return %[[a0]]
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 61a3d933ee1510..fa5d842302d0f4 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -793,9 +793,9 @@ llvm.func @store_int_to_vector(%arg: i32) -> vector<4xi8> {
   %1 = llvm.alloca %0 x vector<2xi16> {alignment = 4 : i64} : (i32) -> !llvm.ptr
   llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> vector<4xi8>
-  // CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<2xi16>
-  // CHECK: %[[BITCAST1:.*]] = llvm.bitcast %[[BITCAST0]] : vector<2xi16> to vector<4xi8>
-  // CHECK: llvm.return %[[BITCAST1]]
+  // The i32 -> vector<2xi16> -> vector<4xi8> bitcast chain folds into one cast.
+  // CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<4xi8>
+  // CHECK: llvm.return %[[CAST]]
   llvm.return %2 : vector<4xi8>
 }
 
More information about the Mlir-commits
mailing list