[Mlir-commits] [mlir] 62084b5 - [mlir][Tosa][NFC] Migrate Tosa dialect to the new fold API

Markus Böck llvmlistbot at llvm.org
Thu Jan 12 00:52:23 PST 2023


Author: Markus Böck
Date: 2023-01-12T09:52:14+01:00
New Revision: 62084b5f37c84d12a15a0a5ebfe41c96d2090b6a

URL: https://github.com/llvm/llvm-project/commit/62084b5f37c84d12a15a0a5ebfe41c96d2090b6a
DIFF: https://github.com/llvm/llvm-project/commit/62084b5f37c84d12a15a0a5ebfe41c96d2090b6a.diff

LOG: [mlir][Tosa][NFC] Migrate Tosa dialect to the new fold API

See https://discourse.llvm.org/t/psa-new-improved-fold-method-signature-has-landed-please-update-your-downstream-projects/67618 for context

Differential Revision: https://reviews.llvm.org/D141527

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 110a334011021..1960e47632eb7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -45,6 +45,7 @@ def Tosa_Dialect : Dialect {
   let cppNamespace = "mlir::tosa";
   let hasConstantMaterializer = 1;
   let useDefaultAttributePrinterParser = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 5f44634d5482d..625c85593a9cf 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -539,7 +539,7 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
   return {};
 }
 
-OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
   auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
   auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
   auto resultTy = getType().dyn_cast<RankedTensorType>();
@@ -549,8 +549,8 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
     return {};
 
   auto resultETy = resultTy.getElementType();
-  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
-  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
 
   if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<FloatType>()) {
     if (lhsAttr.getSplatValue<APFloat>().isZero())
@@ -579,7 +579,7 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
                                                             lhsTy);
 }
 
-OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
   auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
   auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
   auto resultTy = getType().dyn_cast<RankedTensorType>();
@@ -589,8 +589,8 @@ OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
     return {};
 
   auto resultETy = resultTy.getElementType();
-  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
-  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
   if (lhsAttr && lhsAttr.isSplat()) {
     if (resultETy.isa<IntegerType>() && lhsAttr.getSplatValue<APInt>().isZero())
       return lhsAttr;
@@ -646,7 +646,7 @@ DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
 }
 } // namespace
 
-OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
   auto lhs = getInput1();
   auto rhs = getInput2();
   auto lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
@@ -658,8 +658,8 @@ OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
     return {};
 
   auto resultETy = resultTy.getElementType();
-  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
-  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
 
   if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<FloatType>()) {
     auto val = lhsAttr.getSplatValue<APFloat>();
@@ -700,7 +700,7 @@ OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
   return mulBinaryFolder(lhsAttr, rhsAttr, lhsTy, getShift());
 }
 
-OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
   auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
   auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
   auto resultTy = getType().dyn_cast<RankedTensorType>();
@@ -710,8 +710,8 @@ OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
     return {};
 
   auto resultETy = resultTy.getElementType();
-  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
-  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
 
   if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
     if (rhsAttr.getSplatValue<APFloat>().isZero())
@@ -757,10 +757,10 @@ struct APIntFoldGreaterEqual {
 };
 } // namespace
 
-OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
   auto resultTy = getType().dyn_cast<RankedTensorType>();
-  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
-  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
 
   if (!lhsAttr || !rhsAttr)
     return {};
@@ -769,10 +769,10 @@ OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
       lhsAttr, rhsAttr, resultTy);
 }
 
-OpFoldResult GreaterEqualOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
   auto resultTy = getType().dyn_cast<RankedTensorType>();
-  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
-  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
 
   if (!lhsAttr || !rhsAttr)
     return {};
@@ -782,10 +782,10 @@ OpFoldResult GreaterEqualOp::fold(ArrayRef<Attribute> operands) {
       lhsAttr, rhsAttr, resultTy);
 }
 
-OpFoldResult EqualOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
   auto resultTy = getType().dyn_cast<RankedTensorType>();
-  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
-  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
   Value lhs = getInput1();
   Value rhs = getInput2();
   auto lhsTy = lhs.getType().cast<ShapedType>();
@@ -805,11 +805,11 @@ OpFoldResult EqualOp::fold(ArrayRef<Attribute> operands) {
                                                               resultTy);
 }
 
-OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
   if (getInput().getType() == getType())
     return getInput();
 
-  auto operand = operands[0].dyn_cast_or_null<ElementsAttr>();
+  auto operand = adaptor.getInput().dyn_cast_or_null<ElementsAttr>();
   if (!operand)
     return {};
 
@@ -868,13 +868,10 @@ OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
-OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.empty() && "constant has no operands");
-  return getValueAttr();
-}
+OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
 
 #define REDUCE_FOLDER(OP)                                                      \
-  OpFoldResult OP::fold(ArrayRef<Attribute> operands) {                        \
+  OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
     ShapedType inputTy = getInput().getType().cast<ShapedType>();              \
     if (!inputTy.hasRank())                                                    \
       return {};                                                               \
@@ -891,7 +888,7 @@ REDUCE_FOLDER(ReduceProdOp)
 REDUCE_FOLDER(ReduceSumOp)
 #undef REDUCE_FOLDER
 
-OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
   auto inputTy = getInput1().getType().dyn_cast<RankedTensorType>();
   auto outputTy = getType().dyn_cast<RankedTensorType>();
 
@@ -901,7 +898,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
   if (inputTy == outputTy)
     return getInput1();
 
-  auto operand = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto operand = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
   if (operand && outputTy.hasStaticShape() && operand.isSplat()) {
     return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
   }
@@ -909,10 +906,10 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
-OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
   // If the pad is all zeros we can fold this operation away.
-  if (operands[1]) {
-    auto densePad = operands[1].cast<DenseElementsAttr>();
+  if (adaptor.getPadding()) {
+    auto densePad = adaptor.getPadding().cast<DenseElementsAttr>();
     if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
       return getInput1();
     }
@@ -923,7 +920,7 @@ OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
 
 // Fold away cases where a tosa.resize operation returns a copy
 // of the input image.
-OpFoldResult ResizeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
   ArrayRef<int64_t> offset = getOffset();
   ArrayRef<int64_t> border = getBorder();
   ArrayRef<int64_t> scale = getScale();
@@ -952,11 +949,11 @@ OpFoldResult ResizeOp::fold(ArrayRef<Attribute> operands) {
   return input;
 }
 
-OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
   auto operand = getInput();
   auto operandTy = operand.getType().cast<ShapedType>();
   auto axis = getAxis();
-  auto operandAttr = operands[0].dyn_cast_or_null<SplatElementsAttr>();
+  auto operandAttr = adaptor.getInput().dyn_cast_or_null<SplatElementsAttr>();
   if (operandAttr)
     return operandAttr;
 
@@ -967,7 +964,7 @@ OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
-OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
   auto inputTy = getInput().getType().dyn_cast<RankedTensorType>();
   auto outputTy = getType().dyn_cast<RankedTensorType>();
 
@@ -977,10 +974,10 @@ OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
   if (inputTy == outputTy && inputTy.hasStaticShape())
     return getInput();
 
-  if (!operands[0])
+  if (!adaptor.getInput())
     return {};
 
-  auto operand = operands[0].cast<ElementsAttr>();
+  auto operand = adaptor.getInput().cast<ElementsAttr>();
   if (operand.isSplat() && outputTy.hasStaticShape()) {
     return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
   }
@@ -995,11 +992,11 @@ OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
-OpFoldResult tosa::SelectOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
   if (getOnTrue() == getOnFalse())
     return getOnTrue();
 
-  auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  auto predicate = adaptor.getPred().dyn_cast_or_null<DenseIntElementsAttr>();
   if (!predicate)
     return {};
 
@@ -1009,19 +1006,19 @@ OpFoldResult tosa::SelectOp::fold(ArrayRef<Attribute> operands) {
                                                          : getOnFalse();
 }
 
-OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
   bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
   if (allOnes && getInput1().getType() == getType())
     return getInput1();
   return {};
 }
 
-OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
   auto inputTy = getInput1().getType().cast<ShapedType>();
   auto resultTy = getType().cast<ShapedType>();
 
   // Transposing splat values just means reshaping.
-  if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
+  if (auto input = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>()) {
     if (input.isSplat() && resultTy.hasStaticShape() &&
         inputTy.getElementType() == resultTy.getElementType())
       return input.reshape(resultTy);


        


More information about the Mlir-commits mailing list