[Mlir-commits] [mlir] 5161835 - [mlir][tosa] : adding folder and canonicalizer for select

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 25 09:50:33 PDT 2022


Author: lipracer
Date: 2022-03-25T16:50:29Z
New Revision: 5161835d5afa02ce7f0890251fc506c9cd20ec53

URL: https://github.com/llvm/llvm-project/commit/5161835d5afa02ce7f0890251fc506c9cd20ec53
DIFF: https://github.com/llvm/llvm-project/commit/5161835d5afa02ce7f0890251fc506c9cd20ec53.diff

LOG: [mlir][tosa] : adding folder and canonicalizer for select

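Define a canonicalizer and a folder for tosa::select. The folder replaces select(p, x, x) with x, and when the predicate is a splat constant it replaces the select with the chosen operand. The canonicalizer rewrites select(logical_not(p), a, b) as select(p, b, a).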

Reviewed By: mehdi_amini, Mogball

Differential Revision: https://reviews.llvm.org/D121513

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index b13f174022298..66d55d7435263 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1120,14 +1120,16 @@ def Tosa_SelectOp : Tosa_Op<"select", [
   }];
 
   let arguments = (ins
-    I1Tensor:$input1,
-    Tosa_Tensor:$input2,
-    Tosa_Tensor:$input3
+    I1Tensor:$pred,
+    Tosa_Tensor:$on_true,
+    Tosa_Tensor:$on_false
   );
 
   let results = (outs
     Tosa_Tensor:$output
   );
+  let hasCanonicalizeMethod = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 37375184a4394..9fc35eaacce11 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -192,6 +192,17 @@ void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ReshapeConstOptimization>(context);
 }
 
+LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
+  auto notOp = op.pred().getDefiningOp<tosa::LogicalNotOp>();
+  if (!notOp)
+    return failure();
+  rewriter.updateRootInPlace(op, [&]() {
+    op.getOperation()->setOperands(
+        {notOp.input1(), op.on_false(), op.on_true()});
+  });
+  return success();
+}
+
 struct ConstantTransposeOptimization
     : public OpRewritePattern<tosa::TransposeOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -585,12 +596,15 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
     return {};                                                                 \
   }
 
-ReduceFolder(ReduceAllOp) ReduceFolder(ReduceAnyOp) ReduceFolder(ReduceMaxOp)
-    ReduceFolder(ReduceMinOp) ReduceFolder(ReduceProdOp)
-        ReduceFolder(ReduceSumOp)
+ReduceFolder(ReduceAllOp);
+ReduceFolder(ReduceAnyOp);
+ReduceFolder(ReduceMaxOp);
+ReduceFolder(ReduceMinOp);
+ReduceFolder(ReduceProdOp);
+ReduceFolder(ReduceSumOp);
 #undef ReduceFolder
 
-            OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
   auto inputTy = input1().getType().dyn_cast<RankedTensorType>();
   auto outputTy = getType().dyn_cast<RankedTensorType>();
 
@@ -623,6 +637,20 @@ OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+OpFoldResult tosa::SelectOp::fold(ArrayRef<Attribute> operands) {
+  if (on_true() == on_false())
+    return on_true();
+
+  auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (!predicate)
+    return {};
+
+  if (!predicate.isSplat())
+    return {};
+  return predicate.getSplatValue<APInt>().getBoolValue() ? on_true()
+                                                         : on_false();
+}
+
 OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
   bool allOnes = true;
   for (Attribute val : multiples().getValue()) {
@@ -1951,7 +1979,7 @@ LogicalResult WhileOp::inferReturnTypeComponents(
               resultKnowledge[index],
               ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
         resultKnowledge[index] = meet;
-      };
+      }
     }
   }
 

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 41303eebc0693..4f48777194580 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -252,6 +252,48 @@ func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
 
 // -----
 
+// CHECK-LABEL: @select_same_value
+func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %0 = "tosa.select"(%arg0, %arg1, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK: return %arg1
+  // CHECK-NOT: tosa.select
+  return %0 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @select_true_value
+func @select_true_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %c1 = "tosa.const"() {value = dense<1> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
+  %0 = "tosa.select"(%c1, %arg0, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK: return %arg0
+  // CHECK-NOT: tosa.select
+  return %0 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @select_false_value
+func @select_false_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %c0 = "tosa.const"() {value = dense<0> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
+  %0 = "tosa.select"(%c0, %arg0, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK: return %arg1
+  // CHECK-NOT: tosa.select
+  return %0 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @select_not_pred
+func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %0 = "tosa.logical_not"(%arg0) : (tensor<2x3xi1>) -> tensor<2x3xi1>
+  %1 = "tosa.select"(%0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // CHECK: "tosa.select"(%arg0, %arg2, %arg1)
+  return %1 : tensor<2x3xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @reduce_all_fold
 func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg0


        


More information about the Mlir-commits mailing list