[Mlir-commits] [mlir] 4e39007 - [mlir][Tensor][NFC] Migrate Tensor dialect to the new fold API
Markus Böck
llvmlistbot at llvm.org
Tue Jan 17 04:25:45 PST 2023
Author: Markus Böck
Date: 2023-01-17T13:22:11+01:00
New Revision: 4e390073aa9ac84651f031414b5af3eeefd20b14
URL: https://github.com/llvm/llvm-project/commit/4e390073aa9ac84651f031414b5af3eeefd20b14
DIFF: https://github.com/llvm/llvm-project/commit/4e390073aa9ac84651f031414b5af3eeefd20b14.diff
LOG: [mlir][Tensor][NFC] Migrate Tensor dialect to the new fold API
See https://discourse.llvm.org/t/psa-new-improved-fold-method-signature-has-landed-please-update-your-downstream-projects/67618 for context
Differential Revision: https://reviews.llvm.org/D141530
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
index fe49f8db9810d..b27b6ea064c5a 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
@@ -47,6 +47,7 @@ def Tensor_Dialect : Dialect {
let hasCanonicalizer = 1;
let hasConstantMaterializer = 1;
+ let useFoldAPI = kEmitFoldAdaptorFolder;
let dependentDialects = [
"AffineDialect",
"arith::ArithDialect",
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 18ffbe32703dd..d8c337d32e36c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -418,9 +418,9 @@ LogicalResult DimOp::verify() {
return success();
}
-OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
// All forms of folding require a known index.
- auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
+ auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
if (!index)
return {};
@@ -763,16 +763,16 @@ LogicalResult ExtractOp::verify() {
return success();
}
-OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// If this is a splat elements attribute, simply return the value. All of
// the elements of a splat attribute are the same.
- if (Attribute tensor = operands.front())
+ if (Attribute tensor = adaptor.getTensor())
if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
return splatTensor.getSplatValue<Attribute>();
// Collect the constant indices into the tensor.
SmallVector<uint64_t, 8> indices;
- for (Attribute indice : llvm::drop_begin(operands, 1)) {
+ for (Attribute indice : adaptor.getIndices()) {
if (!indice || !indice.isa<IntegerAttr>())
return {};
indices.push_back(indice.cast<IntegerAttr>().getInt());
@@ -800,7 +800,7 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
}
// If this is an elements attribute, query the value at the given indices.
- if (Attribute tensor = operands.front()) {
+ if (Attribute tensor = adaptor.getTensor()) {
auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
if (elementsAttr && elementsAttr.isValidIndex(indices))
return elementsAttr.getValues<Attribute>()[indices];
@@ -837,9 +837,9 @@ void FromElementsOp::build(OpBuilder &builder, OperationState &result,
build(builder, result, resultType, elements);
}
-OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
- if (!llvm::is_contained(operands, nullptr))
- return DenseElementsAttr::get(getType(), operands);
+OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
+ if (!llvm::is_contained(adaptor.getElements(), nullptr))
+ return DenseElementsAttr::get(getType(), adaptor.getElements());
return {};
}
@@ -996,9 +996,9 @@ LogicalResult InsertOp::verify() {
return success();
}
-OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
- Attribute scalar = operands[0];
- Attribute dest = operands[1];
+OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
+ Attribute scalar = adaptor.getScalar();
+ Attribute dest = adaptor.getDest();
if (scalar && dest)
if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
if (scalar == splatDest.getSplatValue<Attribute>())
@@ -1178,7 +1178,7 @@ void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "rank");
}
-OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
// Constant fold rank when the rank of the operand is known.
auto type = getOperand().getType();
auto shapedType = type.dyn_cast<ShapedType>();
@@ -1558,12 +1558,14 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
context);
}
-OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
- return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
+OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
+ return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
+ adaptor.getOperands());
}
-OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
- return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
+OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
+ return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
+ adaptor.getOperands());
}
//===----------------------------------------------------------------------===//
@@ -2050,8 +2052,8 @@ static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
return {};
}
-OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
- if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
+OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
+ if (auto splat = adaptor.getSource().dyn_cast_or_null<SplatElementsAttr>()) {
auto resultType = getResult().getType().cast<ShapedType>();
if (resultType.hasStaticShape())
return splat.resizeSplat(resultType);
@@ -2197,7 +2199,7 @@ static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
return extractOp.getSource();
}
-OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
+OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
getSourceType() == getType() &&
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
@@ -2869,7 +2871,7 @@ Value PadOp::getConstantPaddingValue() {
return padValue;
}
-OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
+OpFoldResult PadOp::fold(FoldAdaptor) {
if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
!getNofold())
return getSource();
@@ -3004,8 +3006,8 @@ void SplatOp::getAsmResultNames(
setNameFn(getResult(), "splat");
}
-OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
- auto constOperand = operands.front();
+OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
+ auto constOperand = adaptor.getInput();
if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
return {};
More information about the Mlir-commits mailing list