[Mlir-commits] [mlir] d57479c - [mlir][tosa] Update SelectOp's input names to match TOSA specification (#127833)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 19 10:20:00 PST 2025


Author: Jerry-Ge
Date: 2025-02-19T10:19:57-08:00
New Revision: d57479cfbe9a6b4dccedfd1221c04973ad90ec97

URL: https://github.com/llvm/llvm-project/commit/d57479cfbe9a6b4dccedfd1221c04973ad90ec97
DIFF: https://github.com/llvm/llvm-project/commit/d57479cfbe9a6b4dccedfd1221c04973ad90ec97.diff

LOG: [mlir][tosa] Update SelectOp's input names to match TOSA specification (#127833)

Renamed the SelectOp operands:
- pred to input1
- on_true to input2
- on_false to input3

Signed-off-by: Jerry Ge <jerry.ge at arm.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 4d5837ca26c91..7cdf79f4dc59d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1190,9 +1190,9 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
   }];
 
   let arguments = (ins
-    Tosa_I1Tensor:$pred,
-    Tosa_Tensor:$on_true,
-    Tosa_Tensor:$on_false
+    Tosa_I1Tensor:$input1,
+    Tosa_Tensor:$input2,
+    Tosa_Tensor:$input3
   );
 
   let results = (outs
@@ -1202,7 +1202,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
   let hasFolder = 1;
 
   let assemblyFormat = [{
-    operands attr-dict `:` `(` type($pred) `,` type($on_true) `,` type($on_false)
+    operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
     `)` `->` type($output)
   }];
 }

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b9bcedb7fe71d..9bfc2aae1d6a5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -65,12 +65,12 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
-  auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
+  auto notOp = op.getInput1().getDefiningOp<tosa::LogicalNotOp>();
   if (!notOp)
     return failure();
   rewriter.modifyOpInPlace(op, [&]() {
     op.getOperation()->setOperands(
-        {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
+        {notOp.getInput1(), op.getInput3(), op.getInput2()});
   });
   return success();
 }
@@ -1131,18 +1131,18 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
 }
 
 OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
-  if (getOnTrue() == getOnFalse())
-    return getOnTrue();
+  if (getInput2() == getInput3())
+    return getInput2();
 
   auto predicate =
-      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
+      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
   if (!predicate)
     return {};
 
   if (!predicate.isSplat())
     return {};
-  return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
-                                                         : getOnFalse();
+  return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
+                                                         : getInput3();
 }
 
 OpFoldResult TileOp::fold(FoldAdaptor adaptor) {

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index 79afc75fd6c8e..87b2a2695351b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -169,9 +169,9 @@ struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
   LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
                                 PatternRewriter &rewriter) const override {
 
-    Value input1 = tosaOp.getPred();
-    Value input2 = tosaOp.getOnTrue();
-    Value input3 = tosaOp.getOnFalse();
+    Value input1 = tosaOp.getInput1();
+    Value input2 = tosaOp.getInput2();
+    Value input3 = tosaOp.getInput3();
     Value output = tosaOp.getResult();
 
     auto outputType = dyn_cast<RankedTensorType>(output.getType());


        


More information about the Mlir-commits mailing list