[Mlir-commits] [mlir] dd39f9b - [MLIR][Arith] Fold trunci with ext if the bit width of the input type of ext is greater than the bit width of the result type of trunci
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 6 06:13:54 PDT 2023
Author: liqinweng
Date: 2023-04-06T21:08:39+08:00
New Revision: dd39f9b418379264ceb6a232dc0b2a5fb18a4203
URL: https://github.com/llvm/llvm-project/commit/dd39f9b418379264ceb6a232dc0b2a5fb18a4203
DIFF: https://github.com/llvm/llvm-project/commit/dd39f9b418379264ceb6a232dc0b2a5fb18a4203.diff
LOG: [MLIR][Arith] Fold trunci with ext if the bit width of the input type of ext is greater than the bit width of the result type of trunci
This patch mainly handles folding trunci with ext, as follows:
trunci(zexti(a)) -> trunci(a)
trunci(sexti(a)) -> trunci(a)
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D140604
Added:
Modified:
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d7ce71a279f59..e203dbc847339 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1245,11 +1245,22 @@ LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
//===----------------------------------------------------------------------===//
OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
- // trunci(zexti(a)) -> a
- // trunci(sexti(a)) -> a
if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
- matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
- return getOperand().getDefiningOp()->getOperand(0);
+ matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
+ Value src = getOperand().getDefiningOp()->getOperand(0);
+ Type srcType = getElementTypeOrSelf(src.getType());
+ Type dstType = getElementTypeOrSelf(getType());
+ // trunci(zexti(a)) -> trunci(a)
+ // trunci(sexti(a)) -> trunci(a)
+ if (srcType.cast<IntegerType>().getWidth() >
+ dstType.cast<IntegerType>().getWidth()) {
+ setOperand(src);
+ return getResult();
+ }
+ // trunci(zexti(a)) -> a
+ // trunci(sexti(a)) -> a
+ return src;
+ }
// trunci(trunci(a)) -> trunci(a))
if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 0170620770823..1f96876e2cf2c 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -619,6 +619,26 @@ func.func @truncExtui(%arg0: i32) -> i32 {
return %trunci : i32
}
+// trunci(extui(a)) where width(a) > width(result) folds to trunci(a).
+// CHECK-LABEL: @truncExtui2
+// CHECK: %[[ARG0:.+]]: i32
+// CHECK-NOT: arith.extui
+// CHECK: %[[RES:.+]] = arith.trunci %[[ARG0]] : i32 to i16
+// CHECK: return %[[RES]]
+func.func @truncExtui2(%arg0: i32) -> i16 {
+ %extui = arith.extui %arg0 : i32 to i64
+ %trunci = arith.trunci %extui : i64 to i16
+ return %trunci : i16
+}
+
+// Vector variant: trunci(extui(a)) where the element width of a is greater
+// than the result element width folds to trunci(a).
+// CHECK-LABEL: @truncExtuiVector
+// CHECK: %[[ARG0:.+]]: vector<2xi32>
+// CHECK-NOT: arith.extui
+// CHECK: %[[RES:.+]] = arith.trunci %[[ARG0]] : vector<2xi32> to vector<2xi16>
+// CHECK: return %[[RES]]
+func.func @truncExtuiVector(%arg0: vector<2xi32>) -> vector<2xi16> {
+ %extui = arith.extui %arg0 : vector<2xi32> to vector<2xi64>
+ %trunci = arith.trunci %extui : vector<2xi64> to vector<2xi16>
+ return %trunci : vector<2xi16>
+}
+
// CHECK-LABEL: @truncExtsi
// CHECK-NOT: trunci
// CHECK: return %arg0
@@ -628,6 +648,26 @@ func.func @truncExtsi(%arg0: i32) -> i32 {
return %trunci : i32
}
+// trunci(extsi(a)) where width(a) > width(result) folds to trunci(a).
+// CHECK-LABEL: @truncExtsi2
+// CHECK: %[[ARG0:.+]]: i32
+// CHECK-NOT: arith.extsi
+// CHECK: %[[RES:.+]] = arith.trunci %[[ARG0]] : i32 to i16
+// CHECK: return %[[RES]]
+func.func @truncExtsi2(%arg0: i32) -> i16 {
+ %extsi = arith.extsi %arg0 : i32 to i64
+ %trunci = arith.trunci %extsi : i64 to i16
+ return %trunci : i16
+}
+
+// Vector variant: trunci(extsi(a)) where the element width of a is greater
+// than the result element width folds to trunci(a).
+// CHECK-LABEL: @truncExtsiVector
+// CHECK: %[[ARG0:.+]]: vector<2xi32>
+// CHECK-NOT: arith.extsi
+// CHECK: %[[RES:.+]] = arith.trunci %[[ARG0]] : vector<2xi32> to vector<2xi16>
+// CHECK: return %[[RES]]
+func.func @truncExtsiVector(%arg0: vector<2xi32>) -> vector<2xi16> {
+ %extsi = arith.extsi %arg0 : vector<2xi32> to vector<2xi64>
+ %trunci = arith.trunci %extsi : vector<2xi64> to vector<2xi16>
+ return %trunci : vector<2xi16>
+}
+
// CHECK-LABEL: @truncConstantSplat
// CHECK: %[[cres:.+]] = arith.constant dense<-2> : vector<4xi8>
// CHECK: return %[[cres]]
More information about the Mlir-commits
mailing list