[Mlir-commits] [mlir] 7872029 - [MLIR] Canonicalization of Integer Cast Operations
William S. Moses
llvmlistbot at llvm.org
Sun May 2 08:27:51 PDT 2021
Author: William S. Moses
Date: 2021-05-02T11:22:18-04:00
New Revision: 78720296f3912b4095eafa5c1277646fd67bae99
URL: https://github.com/llvm/llvm-project/commit/78720296f3912b4095eafa5c1277646fd67bae99
DIFF: https://github.com/llvm/llvm-project/commit/78720296f3912b4095eafa5c1277646fd67bae99.diff
LOG: [MLIR] Canonicalization of Integer Cast Operations
1) Canonicalize IndexCast(SExt(x)) => IndexCast(x)
2) Provide constant folds of sign_extend and truncate
Differential Revision: https://reviews.llvm.org/D101714
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index aaa5a8a7256c..99b8889467e3 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1249,6 +1249,7 @@ def IndexCastOp : ArithmeticCastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
}];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@@ -1711,6 +1712,7 @@ def SignExtendIOp : Std_Op<"sexti", [NoSideEffect,
let printer = [{
return printStandardCastOp(this->getOperation(), p);
}];
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 9260e2757387..fc73947acb9d 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1241,6 +1241,29 @@ OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
return {};
}
+namespace {
+/// Canonicalization pattern: index_cast(sign_extend x) => index_cast(x).
+///
+/// When the operand of an `index_cast` is produced by a `sexti`, the cast is
+/// rewired to read the operand of that `sexti` directly. The `sexti` itself
+/// is not erased by this pattern.
+struct IndexCastOfSExt : public OpRewritePattern<IndexCastOp> {
+  using OpRewritePattern<IndexCastOp>::OpRewritePattern;
+
+  /// Returns success() if `op` was updated to bypass a defining `sexti`, and
+  /// failure() if the operand of `op` is not defined by a `sexti`.
+  LogicalResult matchAndRewrite(IndexCastOp op,
+                                PatternRewriter &rewriter) const override {
+    auto extop = op.getOperand().getDefiningOp<SignExtendIOp>();
+    if (!extop)
+      return failure();
+
+    // Perform the in-place operand swap through the rewriter so that the
+    // pattern driver is notified that `op` changed; mutating the operation
+    // directly bypasses that notification.
+    rewriter.updateRootInPlace(op, [&] { op.setOperand(extop.getOperand()); });
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers the canonicalization patterns of index_cast with `results`;
+/// currently this is only IndexCastOfSExt.
+void IndexCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<IndexCastOfSExt>(context);
+}
+
//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//
@@ -1439,6 +1462,20 @@ static LogicalResult verify(SignExtendIOp op) {
return success();
}
+/// Constant-folds sexti: an IntegerAttr operand is sign-extended to the bit
+/// width of the result type. Any other operand (including an unknown, i.e.
+/// null, attribute) is not folded.
+OpFoldResult SignExtendIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "unary operation takes one operand");
+
+  // A null attribute means the operand is not a known constant.
+  auto constant = operands[0].dyn_cast_or_null<IntegerAttr>();
+  if (!constant)
+    return {};
+
+  Type resultType = getType();
+  unsigned resultWidth = resultType.getIntOrFloatBitWidth();
+  return IntegerAttr::get(resultType, constant.getValue().sext(resultWidth));
+}
+
//===----------------------------------------------------------------------===//
// SignedDivIOp
//===----------------------------------------------------------------------===//
@@ -2686,7 +2723,18 @@ OpFoldResult TruncateIOp::fold(ArrayRef<Attribute> operands) {
matchPattern(getOperand(), m_Op<SignExtendIOp>()))
return getOperand().getDefiningOp()->getOperand(0);
- return nullptr;
+ assert(operands.size() == 1 && "unary operation takes one operand");
+
+ if (!operands[0])
+ return {};
+
+ if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
+
+ return IntegerAttr::get(
+ getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
+ }
+
+ return {};
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 2e8a599ad162..b4942de204ab 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -698,15 +698,13 @@ func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
// Check sign and zero extension and truncation of integers.
// CHECK-LABEL: @integer_extension_and_truncation
-func @integer_extension_and_truncation() {
-// CHECK-NEXT: %0 = llvm.mlir.constant(-3 : i3) : i3
- %0 = constant 5 : i3
-// CHECK-NEXT: = llvm.sext %0 : i3 to i6
- %1 = sexti %0 : i3 to i6
-// CHECK-NEXT: = llvm.zext %0 : i3 to i6
- %2 = zexti %0 : i3 to i6
-// CHECK-NEXT: = llvm.trunc %0 : i3 to i2
- %3 = trunci %0 : i3 to i2
+func @integer_extension_and_truncation(%arg0 : i3) {
+// CHECK-NEXT: = llvm.sext %arg0 : i3 to i6
+ %0 = sexti %arg0 : i3 to i6
+// CHECK-NEXT: = llvm.zext %arg0 : i3 to i6
+ %1 = zexti %arg0 : i3 to i6
+// CHECK-NEXT: = llvm.trunc %arg0 : i3 to i2
+ %2 = trunci %arg0 : i3 to i2
return
}
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index e2b5e7b445b9..8db3065af47d 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -399,3 +399,32 @@ func @select_cmp_ne_select(%arg0: i64, %arg1: i64) -> i64 {
%1 = select %0, %arg0, %arg1 : i64
return %1 : i64
}
+
+// -----
+
+// Checks that index_cast(sexti %arg0) is canonicalized to an index_cast that
+// reads %arg0 directly.
+// CHECK-LABEL: @indexCastOfSignExtend
+// CHECK: %[[res:.+]] = index_cast %arg0 : i8 to index
+// CHECK: return %[[res]]
+func @indexCastOfSignExtend(%arg0: i8) -> index {
+ %ext = sexti %arg0 : i8 to i16
+ %idx = index_cast %ext : i16 to index
+ return %idx : index
+}
+
+// Checks that sexti of the i8 constant -2 folds to the i16 constant -2.
+// CHECK-LABEL: @signExtendConstant
+// CHECK: %[[cres:.+]] = constant -2 : i16
+// CHECK: return %[[cres]]
+func @signExtendConstant() -> i16 {
+ %c-2 = constant -2 : i8
+ %ext = sexti %c-2 : i8 to i16
+ return %ext : i16
+}
+
+// Checks that trunci of the i32 constant -2 folds to the i16 constant -2.
+// NOTE(review): %arg0 is not referenced in the function body; it looks like
+// a leftover and could probably be dropped -- confirm.
+// CHECK-LABEL: @truncConstant
+// CHECK: %[[cres:.+]] = constant -2 : i16
+// CHECK: return %[[cres]]
+func @truncConstant(%arg0: i8) -> i16 {
+ %c-2 = constant -2 : i32
+ %tr = trunci %c-2 : i32 to i16
+ return %tr : i16
+}
More information about the Mlir-commits
mailing list