[Mlir-commits] [mlir] ef8322f - [MLIR][LLVM] Improve bit- and addrspacecast folders (#87745)
llvmlistbot at llvm.org
Fri Apr 5 00:14:16 PDT 2024
Author: Christian Ulmann
Date: 2024-04-05T09:14:13+02:00
New Revision: ef8322f41d542bcf05dbe6079339d707534efbed
URL: https://github.com/llvm/llvm-project/commit/ef8322f41d542bcf05dbe6079339d707534efbed
DIFF: https://github.com/llvm/llvm-project/commit/ef8322f41d542bcf05dbe6079339d707534efbed.diff
LOG: [MLIR][LLVM] Improve bit- and addrspacecast folders (#87745)
This commit extends the folders of chainable casts (bitcast and
addrspacecast) to ensure that they fold a chain of the same casts into a
single cast.
Additionally, this cleans up the canonicalization test file, which used some
outdated constructs.
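As a quick illustration (a sketch of the new folding behavior, mirroring the
fold_bitcast_chain test added below, not additional commit content), a chain of
bitcasts through an intermediate type now collapses into a single cast:

    // Before folding: two chained bitcasts through an intermediate f32 type.
    %c = llvm.bitcast %x : i32 to f32
    %d = llvm.bitcast %c : f32 to vector<2xi16>

    // After folding: a single bitcast from the original source type.
    %d = llvm.bitcast %x : i32 to vector<2xi16>

The same rewrite applies to llvm.addrspacecast chains.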
Added:
Modified:
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Dialect/LLVMIR/canonicalize.mlir
mlir/test/Dialect/LLVMIR/mem2reg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index e5c19a916392e1..f90240a67dcc5f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2761,17 +2761,27 @@ LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
// Folder and verifier for LLVM::BitcastOp
//===----------------------------------------------------------------------===//
-OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
- // bitcast(x : T0, T0) -> x
- if (getArg().getType() == getType())
- return getArg();
- // bitcast(bitcast(x : T0, T1), T0) -> x
- if (auto prev = getArg().getDefiningOp<BitcastOp>())
- if (prev.getArg().getType() == getType())
+/// Folds a cast op that can be chained.
+template <typename T>
+static Value foldChainableCast(T castOp, typename T::FoldAdaptor adaptor) {
+ // cast(x : T0, T0) -> x
+ if (castOp.getArg().getType() == castOp.getType())
+ return castOp.getArg();
+ if (auto prev = castOp.getArg().template getDefiningOp<T>()) {
+ // cast(cast(x : T0, T1), T0) -> x
+ if (prev.getArg().getType() == castOp.getType())
return prev.getArg();
+ // cast(cast(x : T0, T1), T2) -> cast(x: T0, T2)
+ castOp.getArgMutable().set(prev.getArg());
+ return Value{castOp};
+ }
return {};
}
+OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
+ return foldChainableCast(*this, adaptor);
+}
+
LogicalResult LLVM::BitcastOp::verify() {
auto resultType = llvm::dyn_cast<LLVMPointerType>(
extractVectorElementType(getResult().getType()));
@@ -2811,14 +2821,7 @@ LogicalResult LLVM::BitcastOp::verify() {
//===----------------------------------------------------------------------===//
OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
- // addrcast(x : T0, T0) -> x
- if (getArg().getType() == getType())
- return getArg();
- // addrcast(addrcast(x : T0, T1), T0) -> x
- if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>())
- if (prev.getArg().getType() == getType())
- return prev.getArg();
- return {};
+ return foldChainableCast(*this, adaptor);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 5e26fa37b681d7..6b265bbbdbfb2d 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -8,6 +8,8 @@ llvm.func @fold_icmp_eq(%arg0 : i32) -> i1 {
llvm.return %0 : i1
}
+// -----
+
// CHECK-LABEL: @fold_icmp_ne
llvm.func @fold_icmp_ne(%arg0 : vector<2xi32>) -> vector<2xi1> {
// CHECK: %[[C0:.*]] = llvm.mlir.constant(dense<false> : vector<2xi1>) : vector<2xi1>
@@ -16,6 +18,8 @@ llvm.func @fold_icmp_ne(%arg0 : vector<2xi32>) -> vector<2xi1> {
llvm.return %0 : vector<2xi1>
}
+// -----
+
// CHECK-LABEL: @fold_icmp_alloca
llvm.func @fold_icmp_alloca() -> i1 {
// CHECK: %[[C0:.*]] = llvm.mlir.constant(true) : i1
@@ -83,16 +87,18 @@ llvm.func @fold_unrelated_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
// -----
// CHECK-LABEL: fold_bitcast
-// CHECK-SAME: %[[a0:arg[0-9]+]]
-// CHECK-NEXT: llvm.return %[[a0]]
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
+// CHECK-NEXT: llvm.return %[[ARG]]
llvm.func @fold_bitcast(%x : !llvm.ptr) -> !llvm.ptr {
%c = llvm.bitcast %x : !llvm.ptr to !llvm.ptr
llvm.return %c : !llvm.ptr
}
+// -----
+
// CHECK-LABEL: fold_bitcast2
-// CHECK-SAME: %[[a0:arg[0-9]+]]
-// CHECK-NEXT: llvm.return %[[a0]]
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
+// CHECK-NEXT: llvm.return %[[ARG]]
llvm.func @fold_bitcast2(%x : i32) -> i32 {
%c = llvm.bitcast %x : i32 to f32
%d = llvm.bitcast %c : f32 to i32
@@ -101,17 +107,31 @@ llvm.func @fold_bitcast2(%x : i32) -> i32 {
// -----
+// CHECK-LABEL: fold_bitcast_chain
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
+llvm.func @fold_bitcast_chain(%x : i32) -> vector<2xi16> {
+ %c = llvm.bitcast %x : i32 to f32
+ %d = llvm.bitcast %c : f32 to vector<2xi16>
+ // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<2xi16>
+ // CHECK: llvm.return %[[BITCAST]]
+ llvm.return %d : vector<2xi16>
+}
+
+// -----
+
// CHECK-LABEL: fold_addrcast
-// CHECK-SAME: %[[a0:arg[0-9]+]]
-// CHECK-NEXT: llvm.return %[[a0]]
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
+// CHECK-NEXT: llvm.return %[[ARG]]
llvm.func @fold_addrcast(%x : !llvm.ptr) -> !llvm.ptr {
%c = llvm.addrspacecast %x : !llvm.ptr to !llvm.ptr
llvm.return %c : !llvm.ptr
}
+// -----
+
// CHECK-LABEL: fold_addrcast2
-// CHECK-SAME: %[[a0:arg[0-9]+]]
-// CHECK-NEXT: llvm.return %[[a0]]
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
+// CHECK-NEXT: llvm.return %[[ARG]]
llvm.func @fold_addrcast2(%x : !llvm.ptr) -> !llvm.ptr {
%c = llvm.addrspacecast %x : !llvm.ptr to !llvm.ptr<5>
%d = llvm.addrspacecast %c : !llvm.ptr<5> to !llvm.ptr
@@ -120,18 +140,32 @@ llvm.func @fold_addrcast2(%x : !llvm.ptr) -> !llvm.ptr {
// -----
+// CHECK-LABEL: fold_addrcast_chain
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
+llvm.func @fold_addrcast_chain(%x : !llvm.ptr) -> !llvm.ptr<2> {
+ %c = llvm.addrspacecast %x : !llvm.ptr to !llvm.ptr<1>
+ %d = llvm.addrspacecast %c : !llvm.ptr<1> to !llvm.ptr<2>
+ // CHECK: %[[ADDRCAST:.*]] = llvm.addrspacecast %[[ARG]] : !llvm.ptr to !llvm.ptr<2>
+ // CHECK: llvm.return %[[ADDRCAST]]
+ llvm.return %d : !llvm.ptr<2>
+}
+
+// -----
+
// CHECK-LABEL: fold_gep
-// CHECK-SAME: %[[a0:arg[0-9]+]]
-// CHECK-NEXT: llvm.return %[[a0]]
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
+// CHECK-NEXT: llvm.return %[[ARG]]
llvm.func @fold_gep(%x : !llvm.ptr) -> !llvm.ptr {
%c0 = arith.constant 0 : i32
%c = llvm.getelementptr %x[%c0] : (!llvm.ptr, i32) -> !llvm.ptr, i8
llvm.return %c : !llvm.ptr
}
+// -----
+
// CHECK-LABEL: fold_gep_neg
-// CHECK-SAME: %[[a0:arg[0-9]+]]
-// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr inbounds %[[a0]][0, 1]
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
+// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr inbounds %[[ARG]][0, 1]
// CHECK-NEXT: llvm.return %[[RES]]
llvm.func @fold_gep_neg(%x : !llvm.ptr) -> !llvm.ptr {
%c0 = arith.constant 0 : i32
@@ -139,9 +173,11 @@ llvm.func @fold_gep_neg(%x : !llvm.ptr) -> !llvm.ptr {
llvm.return %0 : !llvm.ptr
}
+// -----
+
// CHECK-LABEL: fold_gep_canon
-// CHECK-SAME: %[[a0:arg[0-9]+]]
-// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr %[[a0]][2]
+// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
+// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr %[[ARG]][2]
// CHECK-NEXT: llvm.return %[[RES]]
llvm.func @fold_gep_canon(%x : !llvm.ptr) -> !llvm.ptr {
%c2 = arith.constant 2 : i32
@@ -175,6 +211,8 @@ llvm.func @load_dce(%x : !llvm.ptr) {
llvm.return
}
+// -----
+
llvm.mlir.global external @fp() : !llvm.ptr
// CHECK-LABEL: addr_dce
@@ -184,6 +222,8 @@ llvm.func @addr_dce(%x : !llvm.ptr) {
llvm.return
}
+// -----
+
// CHECK-LABEL: alloca_dce
// CHECK-NEXT: llvm.return
llvm.func @alloca_dce() {
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 61a3d933ee1510..fa5d842302d0f4 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -793,9 +793,8 @@ llvm.func @store_int_to_vector(%arg: i32) -> vector<4xi8> {
%1 = llvm.alloca %0 x vector<2xi16> {alignment = 4 : i64} : (i32) -> !llvm.ptr
llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> vector<4xi8>
- // CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<2xi16>
- // CHECK: %[[BITCAST1:.*]] = llvm.bitcast %[[BITCAST0]] : vector<2xi16> to vector<4xi8>
- // CHECK: llvm.return %[[BITCAST1]]
+ // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<4xi8>
+ // CHECK: llvm.return %[[BITCAST]]
llvm.return %2 : vector<4xi8>
}