[Mlir-commits] [mlir] 62084b5 - [mlir][Tosa][NFC] Migrate Tosa dialect to the new fold API
Markus Böck
llvmlistbot at llvm.org
Thu Jan 12 00:52:23 PST 2023
Author: Markus Böck
Date: 2023-01-12T09:52:14+01:00
New Revision: 62084b5f37c84d12a15a0a5ebfe41c96d2090b6a
URL: https://github.com/llvm/llvm-project/commit/62084b5f37c84d12a15a0a5ebfe41c96d2090b6a
DIFF: https://github.com/llvm/llvm-project/commit/62084b5f37c84d12a15a0a5ebfe41c96d2090b6a.diff
LOG: [mlir][Tosa][NFC] Migrate Tosa dialect to the new fold API
See https://discourse.llvm.org/t/psa-new-improved-fold-method-signature-has-landed-please-update-your-downstream-projects/67618 for context.
Differential Revision: https://reviews.llvm.org/D141527
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 110a334011021..1960e47632eb7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -45,6 +45,7 @@ def Tosa_Dialect : Dialect {
let cppNamespace = "mlir::tosa";
let hasConstantMaterializer = 1;
let useDefaultAttributePrinterParser = 1;
+ let useFoldAPI = kEmitFoldAdaptorFolder;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 5f44634d5482d..625c85593a9cf 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -539,7 +539,7 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
return {};
}
-OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
auto resultTy = getType().dyn_cast<RankedTensorType>();
@@ -549,8 +549,8 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
return {};
auto resultETy = resultTy.getElementType();
- auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
+ auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<FloatType>()) {
if (lhsAttr.getSplatValue<APFloat>().isZero())
@@ -579,7 +579,7 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
lhsTy);
}
-OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
auto resultTy = getType().dyn_cast<RankedTensorType>();
@@ -589,8 +589,8 @@ OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
return {};
auto resultETy = resultTy.getElementType();
- auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
+ auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
if (lhsAttr && lhsAttr.isSplat()) {
if (resultETy.isa<IntegerType>() && lhsAttr.getSplatValue<APInt>().isZero())
return lhsAttr;
@@ -646,7 +646,7 @@ DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
}
} // namespace
-OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
auto lhs = getInput1();
auto rhs = getInput2();
auto lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
@@ -658,8 +658,8 @@ OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
return {};
auto resultETy = resultTy.getElementType();
- auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
+ auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<FloatType>()) {
auto val = lhsAttr.getSplatValue<APFloat>();
@@ -700,7 +700,7 @@ OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
return mulBinaryFolder(lhsAttr, rhsAttr, lhsTy, getShift());
}
-OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
auto resultTy = getType().dyn_cast<RankedTensorType>();
@@ -710,8 +710,8 @@ OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
return {};
auto resultETy = resultTy.getElementType();
- auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
+ auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
if (rhsAttr.getSplatValue<APFloat>().isZero())
@@ -757,10 +757,10 @@ struct APIntFoldGreaterEqual {
};
} // namespace
-OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
auto resultTy = getType().dyn_cast<RankedTensorType>();
- auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
+ auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
if (!lhsAttr || !rhsAttr)
return {};
@@ -769,10 +769,10 @@ OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
lhsAttr, rhsAttr, resultTy);
}
-OpFoldResult GreaterEqualOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
auto resultTy = getType().dyn_cast<RankedTensorType>();
- auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
+ auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
if (!lhsAttr || !rhsAttr)
return {};
@@ -782,10 +782,10 @@ OpFoldResult GreaterEqualOp::fold(ArrayRef<Attribute> operands) {
lhsAttr, rhsAttr, resultTy);
}
-OpFoldResult EqualOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
auto resultTy = getType().dyn_cast<RankedTensorType>();
- auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
+ auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
Value lhs = getInput1();
Value rhs = getInput2();
auto lhsTy = lhs.getType().cast<ShapedType>();
@@ -805,11 +805,11 @@ OpFoldResult EqualOp::fold(ArrayRef<Attribute> operands) {
resultTy);
}
-OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
if (getInput().getType() == getType())
return getInput();
- auto operand = operands[0].dyn_cast_or_null<ElementsAttr>();
+ auto operand = adaptor.getInput().dyn_cast_or_null<ElementsAttr>();
if (!operand)
return {};
@@ -868,13 +868,10 @@ OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
return {};
}
-OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.empty() && "constant has no operands");
- return getValueAttr();
-}
+OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
#define REDUCE_FOLDER(OP) \
- OpFoldResult OP::fold(ArrayRef<Attribute> operands) { \
+ OpFoldResult OP::fold(FoldAdaptor adaptor) { \
ShapedType inputTy = getInput().getType().cast<ShapedType>(); \
if (!inputTy.hasRank()) \
return {}; \
@@ -891,7 +888,7 @@ REDUCE_FOLDER(ReduceProdOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER
-OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
auto inputTy = getInput1().getType().dyn_cast<RankedTensorType>();
auto outputTy = getType().dyn_cast<RankedTensorType>();
@@ -901,7 +898,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
if (inputTy == outputTy)
return getInput1();
- auto operand = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+ auto operand = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
if (operand && outputTy.hasStaticShape() && operand.isSplat()) {
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
}
@@ -909,10 +906,10 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
return {};
}
-OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
// If the pad is all zeros we can fold this operation away.
- if (operands[1]) {
- auto densePad = operands[1].cast<DenseElementsAttr>();
+ if (adaptor.getPadding()) {
+ auto densePad = adaptor.getPadding().cast<DenseElementsAttr>();
if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
return getInput1();
}
@@ -923,7 +920,7 @@ OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
// Fold away cases where a tosa.resize operation returns a copy
// of the input image.
-OpFoldResult ResizeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
ArrayRef<int64_t> offset = getOffset();
ArrayRef<int64_t> border = getBorder();
ArrayRef<int64_t> scale = getScale();
@@ -952,11 +949,11 @@ OpFoldResult ResizeOp::fold(ArrayRef<Attribute> operands) {
return input;
}
-OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
auto operand = getInput();
auto operandTy = operand.getType().cast<ShapedType>();
auto axis = getAxis();
- auto operandAttr = operands[0].dyn_cast_or_null<SplatElementsAttr>();
+ auto operandAttr = adaptor.getInput().dyn_cast_or_null<SplatElementsAttr>();
if (operandAttr)
return operandAttr;
@@ -967,7 +964,7 @@ OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
return {};
}
-OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
auto inputTy = getInput().getType().dyn_cast<RankedTensorType>();
auto outputTy = getType().dyn_cast<RankedTensorType>();
@@ -977,10 +974,10 @@ OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
if (inputTy == outputTy && inputTy.hasStaticShape())
return getInput();
- if (!operands[0])
+ if (!adaptor.getInput())
return {};
- auto operand = operands[0].cast<ElementsAttr>();
+ auto operand = adaptor.getInput().cast<ElementsAttr>();
if (operand.isSplat() && outputTy.hasStaticShape()) {
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
}
@@ -995,11 +992,11 @@ OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
return {};
}
-OpFoldResult tosa::SelectOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
if (getOnTrue() == getOnFalse())
return getOnTrue();
- auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+ auto predicate = adaptor.getPred().dyn_cast_or_null<DenseIntElementsAttr>();
if (!predicate)
return {};
@@ -1009,19 +1006,19 @@ OpFoldResult tosa::SelectOp::fold(ArrayRef<Attribute> operands) {
: getOnFalse();
}
-OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
if (allOnes && getInput1().getType() == getType())
return getInput1();
return {};
}
-OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
auto inputTy = getInput1().getType().cast<ShapedType>();
auto resultTy = getType().cast<ShapedType>();
// Transposing splat values just means reshaping.
- if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
+ if (auto input = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>()) {
if (input.isSplat() && resultTy.hasStaticShape() &&
inputTy.getElementType() == resultTy.getElementType())
return input.reshape(resultTy);
More information about the Mlir-commits mailing list