[Mlir-commits] [mlir] e5f2898 - [MLIR][STD] Fold trunci (zexti).

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 27 11:40:46 PDT 2021


Author: KareemErgawy-TomTom
Date: 2021-03-27T19:40:10+01:00
New Revision: e5f2898bc751aab581193ad87cf887e5c4c8bcec

URL: https://github.com/llvm/llvm-project/commit/e5f2898bc751aab581193ad87cf887e5c4c8bcec
DIFF: https://github.com/llvm/llvm-project/commit/e5f2898bc751aab581193ad87cf887e5c4c8bcec.diff

LOG: [MLIR][STD] Fold trunci (zexti).

This patch folds the following pattern:

```
  %arg0 = ...
  %0 = zexti %arg0 : i1 to i8
  %1 = trunci %0 : i8 to i1
```

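into just `%arg0`, provided the `trunci` result type is the same as the type of `%arg0` (otherwise the ops are left untouched).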

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D99453

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 84c152b351a0..fcfe8f1850e9 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2070,6 +2070,8 @@ def TruncateIOp : Std_Op<"trunci", [NoSideEffect,
   let printer = [{
     return printStandardCastOp(this->getOperation(), p);
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 2f2f36e502d6..4b53bf47b6e5 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2181,6 +2181,14 @@ static LogicalResult verify(TruncateIOp op) {
   return success();
 }
 
+OpFoldResult TruncateIOp::fold(ArrayRef<Attribute> operands) {
+  // trunci(zexti(a)) -> a, valid only when `a` already has the result type.
+  if (auto zext = getOperand().getDefiningOp<ZeroExtendIOp>())
+    if (zext.getOperand().getType() == getType())
+      return zext.getOperand();
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // UnsignedDivIOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index a65c46452cc8..fdf6f880ffec 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1059,3 +1059,53 @@ func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
   return %2 : tensor<?x?x?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @fold_trunci
+// CHECK-SAME:    (%[[ARG0:[0-9a-z]*]]: i1)
+func @fold_trunci(%arg0: i1) -> i1 attributes {} {
+  // CHECK-NEXT: return %[[ARG0]] : i1
+  %0 = zexti %arg0 : i1 to i8
+  %1 = trunci %0 : i8 to i1
+  return %1 : i1
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_trunci_vector
+// CHECK-SAME:    (%[[ARG0:[0-9a-z]*]]: vector<4xi1>)
+func @fold_trunci_vector(%arg0: vector<4xi1>) -> vector<4xi1> attributes {} {
+  // CHECK-NEXT: return %[[ARG0]] : vector<4xi1>
+  %0 = zexti %arg0 : vector<4xi1> to vector<4xi8>
+  %1 = trunci %0 : vector<4xi8> to vector<4xi1>
+  return %1 : vector<4xi1>
+}
+
+// -----
+
+// TODO: Canonicalize trunci(zexti(a)) with a narrower result type into:
+//   zexti %arg0 : i1 to i2
+
+// CHECK-LABEL: func @do_not_fold_trunci
+// CHECK-SAME:    (%[[ARG0:[0-9a-z]*]]: i1)
+func @do_not_fold_trunci(%arg0: i1) -> i2 attributes {} {
+  // CHECK-NEXT: zexti %[[ARG0]] : i1 to i8
+  // CHECK-NEXT: %[[RES:[0-9a-z]*]] = trunci %{{.*}} : i8 to i2
+  // CHECK-NEXT: return %[[RES]] : i2
+  %0 = zexti %arg0 : i1 to i8
+  %1 = trunci %0 : i8 to i2
+  return %1 : i2
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_fold_trunci_vector
+// CHECK-SAME:    (%[[ARG0:[0-9a-z]*]]: vector<4xi1>)
+func @do_not_fold_trunci_vector(%arg0: vector<4xi1>) -> vector<4xi2> attributes {} {
+  // CHECK-NEXT: zexti %[[ARG0]] : vector<4xi1> to vector<4xi8>
+  // CHECK-NEXT: %[[RES:[0-9a-z]*]] = trunci %{{.*}} : vector<4xi8> to vector<4xi2>
+  // CHECK-NEXT: return %[[RES]] : vector<4xi2>
+  %0 = zexti %arg0 : vector<4xi1> to vector<4xi8>
+  %1 = trunci %0 : vector<4xi8> to vector<4xi2>
+  return %1 : vector<4xi2>
+}


        


More information about the Mlir-commits mailing list