[Mlir-commits] [mlir] 5161835 - [mlir][tosa] : adding folder and canonicalizer for select
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 25 09:50:33 PDT 2022
Author: lipracer
Date: 2022-03-25T16:50:29Z
New Revision: 5161835d5afa02ce7f0890251fc506c9cd20ec53
URL: https://github.com/llvm/llvm-project/commit/5161835d5afa02ce7f0890251fc506c9cd20ec53
DIFF: https://github.com/llvm/llvm-project/commit/5161835d5afa02ce7f0890251fc506c9cd20ec53.diff
LOG: [mlir][tosa] : adding folder and canonicalizer for select
Define a canonicalizer and a folder for tosa::select.
Reviewed By: mehdi_amini, Mogball
Differential Revision: https://reviews.llvm.org/D121513
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index b13f174022298..66d55d7435263 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1120,14 +1120,16 @@ def Tosa_SelectOp : Tosa_Op<"select", [
}];
let arguments = (ins
- I1Tensor:$input1,
- Tosa_Tensor:$input2,
- Tosa_Tensor:$input3
+ I1Tensor:$pred,
+ Tosa_Tensor:$on_true,
+ Tosa_Tensor:$on_false
);
let results = (outs
Tosa_Tensor:$output
);
+ let hasCanonicalizeMethod = 1;
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 37375184a4394..9fc35eaacce11 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -192,6 +192,17 @@ void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ReshapeConstOptimization>(context);
}
+/// Canonicalizes `select(logical_not(c), x, y)` into `select(c, y, x)`:
+/// the negation is removed from the predicate by swapping the two value
+/// operands. The select is modified in place. Returns failure when the
+/// predicate is not produced by a tosa.logical_not.
+LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
+  if (auto negation = op.pred().getDefiningOp<tosa::LogicalNotOp>()) {
+    Value basePred = negation.input1();
+    rewriter.updateRootInPlace(op, [&]() {
+      // The braced list is fully evaluated (reading the current on_true /
+      // on_false) before setOperands overwrites the operands.
+      op->setOperands({basePred, op.on_false(), op.on_true()});
+    });
+    return success();
+  }
+  return failure();
+}
+
struct ConstantTransposeOptimization
: public OpRewritePattern<tosa::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
@@ -585,12 +596,15 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
return {}; \
}
-ReduceFolder(ReduceAllOp) ReduceFolder(ReduceAnyOp) ReduceFolder(ReduceMaxOp)
- ReduceFolder(ReduceMinOp) ReduceFolder(ReduceProdOp)
- ReduceFolder(ReduceSumOp)
+ReduceFolder(ReduceAllOp);
+ReduceFolder(ReduceAnyOp);
+ReduceFolder(ReduceMaxOp);
+ReduceFolder(ReduceMinOp);
+ReduceFolder(ReduceProdOp);
+ReduceFolder(ReduceSumOp);
#undef ReduceFolder
- OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
auto inputTy = input1().getType().dyn_cast<RankedTensorType>();
auto outputTy = getType().dyn_cast<RankedTensorType>();
@@ -623,6 +637,20 @@ OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
return {};
}
+/// Folds tosa.select to one of its value operands when the choice is known
+/// statically:
+///   - both value operands are the same SSA value, or
+///   - the predicate is a splat constant (all-true or all-false).
+/// An operand is only forwarded when its type equals the op's result type:
+/// a folder must not change the type seen by users, and a select operand may
+/// carry a different (e.g. broadcast or less-refined) shape than the result.
+OpFoldResult tosa::SelectOp::fold(ArrayRef<Attribute> operands) {
+  auto forwardIfSameType = [&](Value candidate) -> OpFoldResult {
+    if (candidate.getType() != getType())
+      return {};
+    return candidate;
+  };
+
+  if (on_true() == on_false())
+    return forwardIfSameType(on_true());
+
+  // operands[0] is the constant value of `pred`, or null when not constant.
+  auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (!predicate || !predicate.isSplat())
+    return {};
+  return forwardIfSameType(predicate.getSplatValue<APInt>().getBoolValue()
+                               ? on_true()
+                               : on_false());
+}
+
OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
bool allOnes = true;
for (Attribute val : multiples().getValue()) {
@@ -1951,7 +1979,7 @@ LogicalResult WhileOp::inferReturnTypeComponents(
resultKnowledge[index],
ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
resultKnowledge[index] = meet;
- };
+ }
}
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 41303eebc0693..4f48777194580 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -252,6 +252,48 @@ func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
// -----
+// CHECK-LABEL: @select_same_value
+func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %0 = "tosa.select"(%arg0, %arg1, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // Both value operands are %arg1, so the select folds to %arg1 whatever the
+  // predicate is. The negative match must come before the return match: it
+  // only scans the text between neighbouring matches, and a surviving select
+  // would be printed before the return.
+  // CHECK-NOT: tosa.select
+  // CHECK: return %arg1
+  return %0 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @select_true_value
+func @select_true_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %c1 = "tosa.const"() {value = dense<1> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
+  %0 = "tosa.select"(%c1, %arg0, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // An all-true splat predicate selects the first value operand. The
+  // negative match precedes the return match so that it scans the region
+  // where a surviving select would be printed.
+  // CHECK-NOT: tosa.select
+  // CHECK: return %arg0
+  return %0 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @select_false_value
+func @select_false_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %c0 = "tosa.const"() {value = dense<0> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
+  %0 = "tosa.select"(%c0, %arg0, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // An all-false splat predicate selects the second value operand. The
+  // negative match precedes the return match so that it scans the region
+  // where a surviving select would be printed.
+  // CHECK-NOT: tosa.select
+  // CHECK: return %arg1
+  return %0 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @select_not_pred
+func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  %0 = "tosa.logical_not"(%arg0) : (tensor<2x3xi1>) -> tensor<2x3xi1>
+  %1 = "tosa.select"(%0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  // The negation is absorbed by swapping the value operands, which leaves the
+  // logical_not without users; verify it is gone before the rewritten select.
+  // CHECK-NOT: tosa.logical_not
+  // CHECK: "tosa.select"(%arg0, %arg2, %arg1)
+  return %1 : tensor<2x3xi32>
+}
+
+// -----
+
// CHECK-LABEL: @reduce_all_fold
func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg0
More information about the Mlir-commits mailing list