[Mlir-commits] [mlir] 7872029 - [MLIR] Canonicalization of Integer Cast Operations

William S. Moses llvmlistbot at llvm.org
Sun May 2 08:27:51 PDT 2021


Author: William S. Moses
Date: 2021-05-02T11:22:18-04:00
New Revision: 78720296f3912b4095eafa5c1277646fd67bae99

URL: https://github.com/llvm/llvm-project/commit/78720296f3912b4095eafa5c1277646fd67bae99
DIFF: https://github.com/llvm/llvm-project/commit/78720296f3912b4095eafa5c1277646fd67bae99.diff

LOG: [MLIR] Canonicalization of Integer Cast Operations

1) Canonicalize IndexCast(SExt(x)) => IndexCast(x)
2) Provide constant folders for sign_extend (sexti) and truncate (trunci)

Differential Revision: https://reviews.llvm.org/D101714

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index aaa5a8a7256c..99b8889467e3 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1249,6 +1249,7 @@ def IndexCastOp : ArithmeticCastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
   }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1711,6 +1712,7 @@ def SignExtendIOp : Std_Op<"sexti", [NoSideEffect,
   let printer = [{
     return printStandardCastOp(this->getOperation(), p);
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 9260e2757387..fc73947acb9d 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1241,6 +1241,29 @@ OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
   return {};
 }
 
+namespace {
+/// index_cast(sexti x) => index_cast(x). Assumes index_cast sign-extends
+/// or truncates its operand, which makes the intermediate sexti redundant.
+struct IndexCastOfSExt : public OpRewritePattern<IndexCastOp> {
+  using OpRewritePattern<IndexCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IndexCastOp op,
+                                PatternRewriter &rewriter) const override {
+    auto extOp = op.getOperand().getDefiningOp<SignExtendIOp>();
+    if (!extOp)
+      return failure();
+    // In-place updates must go through the rewriter so the driver is notified.
+    rewriter.updateRootInPlace(op, [&] { op.setOperand(extOp.getOperand()); });
+    return success();
+  }
+};
+} // namespace
+
+void IndexCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  results.insert<IndexCastOfSExt>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//
@@ -1439,6 +1462,20 @@ static LogicalResult verify(SignExtendIOp op) {
   return success();
 }
 
+OpFoldResult SignExtendIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "unary operation takes one operand");
+
+  // Only an IntegerAttr constant operand is folded; a missing constant or
+  // any other attribute kind leaves the op unchanged.
+  auto input = operands[0].dyn_cast_or_null<IntegerAttr>();
+  if (!input)
+    return {};
+
+  // Widen by replicating the sign bit up to the result's bit width.
+  unsigned resultWidth = getType().getIntOrFloatBitWidth();
+  return IntegerAttr::get(getType(), input.getValue().sext(resultWidth));
+}
+
 //===----------------------------------------------------------------------===//
 // SignedDivIOp
 //===----------------------------------------------------------------------===//
@@ -2686,7 +2723,18 @@ OpFoldResult TruncateIOp::fold(ArrayRef<Attribute> operands) {
       matchPattern(getOperand(), m_Op<SignExtendIOp>()))
     return getOperand().getDefiningOp()->getOperand(0);
 
-  return nullptr;
+  assert(operands.size() == 1 && "unary operation takes one operand");
+
+  if (!operands[0])
+    return {};
+
+  if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
+
+    return IntegerAttr::get(
+        getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
+  }
+
+  return {};
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 2e8a599ad162..b4942de204ab 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -698,15 +698,13 @@ func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
 
 // Check sign and zero extension and truncation of integers.
 // CHECK-LABEL: @integer_extension_and_truncation
-func @integer_extension_and_truncation() {
-// CHECK-NEXT:  %0 = llvm.mlir.constant(-3 : i3) : i3
-  %0 = constant 5 : i3
-// CHECK-NEXT: = llvm.sext %0 : i3 to i6
-  %1 = sexti %0 : i3 to i6
-// CHECK-NEXT: = llvm.zext %0 : i3 to i6
-  %2 = zexti %0 : i3 to i6
-// CHECK-NEXT: = llvm.trunc %0 : i3 to i2
-   %3 = trunci %0 : i3 to i2
+func @integer_extension_and_truncation(%arg0 : i3) {
+// CHECK-NEXT: = llvm.sext %arg0 : i3 to i6
+  %0 = sexti %arg0 : i3 to i6
+// CHECK-NEXT: = llvm.zext %arg0 : i3 to i6
+  %1 = zexti %arg0 : i3 to i6
+// CHECK-NEXT: = llvm.trunc %arg0 : i3 to i2
+  %2 = trunci %arg0 : i3 to i2
   return
 }
 

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index e2b5e7b445b9..8db3065af47d 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -399,3 +399,32 @@ func @select_cmp_ne_select(%arg0: i64, %arg1: i64) -> i64 {
   %1 = select %0, %arg0, %arg1 : i64
   return %1 : i64
 }
+
+// -----
+
+// CHECK-LABEL: @indexCastOfSignExtend
+//  CHECK-NEXT:   %[[res:.+]] = index_cast %arg0 : i8 to index
+//  CHECK-NEXT:   return %[[res]]
+func @indexCastOfSignExtend(%arg0: i8) -> index {
+  %ext = sexti %arg0 : i8 to i16
+  %idx = index_cast %ext : i16 to index
+  return %idx : index
+}
+
+// CHECK-LABEL: @signExtendConstant
+//  CHECK-NEXT:   %[[cres:.+]] = constant -2 : i16
+//  CHECK-NEXT:   return %[[cres]]
+func @signExtendConstant() -> i16 {
+  %c-2 = constant -2 : i8
+  %ext = sexti %c-2 : i8 to i16
+  return %ext : i16
+}
+
+// CHECK-LABEL: @truncConstant
+//  CHECK-NEXT:   %[[cres:.+]] = constant -2 : i16
+//  CHECK-NEXT:   return %[[cres]]
+func @truncConstant() -> i16 {
+  %c-2 = constant -2 : i32
+  %tr = trunci %c-2 : i32 to i16
+  return %tr : i16
+}


        


More information about the Mlir-commits mailing list