[Mlir-commits] [mlir] ab85aec - [mlir][arith] Add missing canon pattern `trunci(ext*i(x)) -> ext*i(x)`

Jakub Kuderski llvmlistbot at llvm.org
Thu Apr 27 08:22:26 PDT 2023


Author: Jakub Kuderski
Date: 2023-04-27T11:21:59-04:00
New Revision: ab85aec1affc92647c195f736d1bac69976baeb8

URL: https://github.com/llvm/llvm-project/commit/ab85aec1affc92647c195f736d1bac69976baeb8
DIFF: https://github.com/llvm/llvm-project/commit/ab85aec1affc92647c195f736d1bac69976baeb8.diff

LOG: [mlir][arith] Add missing canon pattern `trunci(ext*i(x)) -> ext*i(x)`

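This pattern fires when the truncation discards only bits that were introduced by the extension. In other words, the truncated result is still strictly wider than the original (pre-extension) operand. In that case the extend+truncate pair is replaced by a single, narrower extension of the original operand, e.g. `trunci(extui(x : i8) : i32) : i16` -> `extui(x : i8) : i16`.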

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D149286

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Dialect/Arith/canonicalize.mlir
    mlir/test/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index d4c6b8184751f..ba1f3f8bd1d86 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -319,6 +319,20 @@ def TruncationMatchesShiftAmount :
       CPred<"(getScalarOrElementWidth($0) - getScalarOrElementWidth($1)) == "
               "*getIntOrSplatIntValue($2)">]>>;
 
+// trunci(extsi(x)) -> extsi(x), when only the sign-extension bits are truncated
+def TruncIExtSIToExtSI :
+    Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x)),
+        (Arith_ExtSIOp $x),
+        [(ValueWiderThan $ext, $tr),
+         (ValueWiderThan $tr, $x)]>;
+
+// trunci(extui(x)) -> extui(x), when only the zero-extension bits are truncated
+def TruncIExtUIToExtUI :
+    Pat<(Arith_TruncIOp:$tr (Arith_ExtUIOp:$ext $x)),
+        (Arith_ExtUIOp $x),
+        [(ValueWiderThan $ext, $tr),
+         (ValueWiderThan $tr, $x)]>;
+
 // trunci(shrsi(x, c)) -> trunci(shrui(x, c))
 def TruncIShrSIToTrunciShrUI :
     Pat<(Arith_TruncIOp:$tr

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 446bb6461077d..b4b0572fdee75 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1290,8 +1290,9 @@ bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 
 void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
-  patterns.add<TruncIShrSIToTrunciShrUI, TruncIShrUIMulIToMulSIExtended,
-               TruncIShrUIMulIToMulUIExtended>(context);
+  patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
+               TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
+      context);
 }
 
 LogicalResult arith::TruncIOp::verify() {

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 1f96876e2cf2c..14589b2915e94 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -629,6 +629,16 @@ func.func @truncExtui2(%arg0: i32) -> i16 {
   return %trunci : i16
 }
 
+// CHECK-LABEL: @truncExtui3
+//       CHECK:  %[[ARG0:.+]]: i8
+//       CHECK:  %[[CST:.*]] = arith.extui %[[ARG0:.+]] : i8 to i16
+//       CHECK:   return  %[[CST:.*]] : i16
+func.func @truncExtui3(%arg0: i8) -> i16 {
+  %extui = arith.extui %arg0 : i8 to i32
+  %trunci = arith.trunci %extui : i32 to i16
+  return %trunci : i16
+}
+
 // CHECK-LABEL: @truncExtuiVector
 //       CHECK:  %[[ARG0:.+]]: vector<2xi32>
 //       CHECK:  %[[CST:.*]] = arith.trunci %[[ARG0:.+]] : vector<2xi32> to vector<2xi16>
@@ -658,6 +668,16 @@ func.func @truncExtsi2(%arg0: i32) -> i16 {
   return %trunci : i16
 }
 
+// CHECK-LABEL: @truncExtsi3
+//       CHECK:  %[[ARG0:.+]]: i8
+//       CHECK:  %[[CST:.*]] = arith.extsi %[[ARG0:.+]] : i8 to i16
+//       CHECK:   return  %[[CST:.*]] : i16
+func.func @truncExtsi3(%arg0: i8) -> i16 {
+  %extsi = arith.extsi %arg0 : i8 to i32
+  %trunci = arith.trunci %extsi : i32 to i16
+  return %trunci : i16
+}
+
 // CHECK-LABEL: @truncExtsiVector
 //       CHECK:  %[[ARG0:.+]]: vector<2xi32>
 //       CHECK:  %[[CST:.*]] = arith.trunci %[[ARG0:.+]] : vector<2xi32> to vector<2xi16>

diff  --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 5cc0eb539ecf5..47a19bb598c25 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1107,14 +1107,10 @@ func.func @fold_trunci_vector(%arg0: vector<4xi1>) -> vector<4xi1> attributes {}
 
 // -----
 
-// TODO Canonicalize this into:
-//   arith.extui %arg0 : i1 to i2
-
-// CHECK-LABEL: func @do_not_fold_trunci
+// CHECK-LABEL: func @fold_trunci
 // CHECK-SAME:    (%[[ARG0:[0-9a-z]*]]: i1)
-func.func @do_not_fold_trunci(%arg0: i1) -> i2 attributes {} {
-  // CHECK-NEXT: arith.extui %[[ARG0]] : i1 to i8
-  // CHECK-NEXT: %[[RES:[0-9a-z]*]] = arith.trunci %{{.*}} : i8 to i2
+func.func @fold_trunci(%arg0: i1) -> i2 attributes {} {
+  // CHECK-NEXT: %[[RES:[0-9a-z]*]] = arith.extui %[[ARG0]] : i1 to i2
   // CHECK-NEXT: return %[[RES]] : i2
   %0 = arith.extui %arg0 : i1 to i8
   %1 = arith.trunci %0 : i8 to i2
@@ -1123,11 +1119,10 @@ func.func @do_not_fold_trunci(%arg0: i1) -> i2 attributes {} {
 
 // -----
 
-// CHECK-LABEL: func @do_not_fold_trunci_vector
+// CHECK-LABEL: func @fold_trunci_vector
 // CHECK-SAME:    (%[[ARG0:[0-9a-z]*]]: vector<4xi1>)
-func.func @do_not_fold_trunci_vector(%arg0: vector<4xi1>) -> vector<4xi2> attributes {} {
-  // CHECK-NEXT: arith.extui %[[ARG0]] : vector<4xi1> to vector<4xi8>
-  // CHECK-NEXT: %[[RES:[0-9a-z]*]] = arith.trunci %{{.*}} : vector<4xi8> to vector<4xi2>
+func.func @fold_trunci_vector(%arg0: vector<4xi1>) -> vector<4xi2> attributes {} {
+  // CHECK-NEXT: %[[RES:[0-9a-z]*]] = arith.extui %[[ARG0]] : vector<4xi1> to vector<4xi2>
   // CHECK-NEXT: return %[[RES]] : vector<4xi2>
   %0 = arith.extui %arg0 : vector<4xi1> to vector<4xi8>
   %1 = arith.trunci %0 : vector<4xi8> to vector<4xi2>


        


More information about the Mlir-commits mailing list