[Mlir-commits] [mlir] c21e88c - [mlir][Tensor] Avoid dropping attributes for `tensor.pad` operations during canonicalization.
Mahesh Ravishankar
llvmlistbot at llvm.org
Mon Mar 20 14:04:00 PDT 2023
Author: Mahesh Ravishankar
Date: 2023-03-20T21:03:46Z
New Revision: c21e88cc02617e0f04807a8dcf164b405d67d5e4
URL: https://github.com/llvm/llvm-project/commit/c21e88cc02617e0f04807a8dcf164b405d67d5e4
DIFF: https://github.com/llvm/llvm-project/commit/c21e88cc02617e0f04807a8dcf164b405d67d5e4.diff
LOG: [mlir][Tensor] Avoid dropping attributes for `tensor.pad` operations during canonicalization.
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D146440
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index cc8bbd570ef66..3c3fa70e161f7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "llvm/ADT/StringSet.h"
#include <optional>
@@ -461,18 +462,10 @@ struct GenerateLoopNest {
/// Returns an attribute list that excludes pre-defined attributes.
template <typename OpTy>
SmallVector<NamedAttribute> getPrunedAttributeList(OpTy op) {
- llvm::StringSet<> elidedAttrs;
- elidedAttrs.insert(op.getAttributeNames().begin(),
- op.getAttributeNames().end());
+ auto elidedAttrs = llvm::to_vector(op.getAttributeNames());
if (isa<linalg::LinalgOp>(op.getOperation()))
- elidedAttrs.insert(LinalgDialect::kMemoizedIndexingMapsAttrName);
- SmallVector<NamedAttribute> attrs;
- for (auto attr : op->getAttrs()) {
- if (elidedAttrs.count(attr.getName()))
- continue;
- attrs.push_back(attr);
- }
- return attrs;
+ elidedAttrs.push_back(LinalgDialect::kMemoizedIndexingMapsAttrName);
+ return getPrunedAttributeList(op, elidedAttrs);
}
} // namespace linalg
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 09b7775dcaae4..66d6dcc7b27ed 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1295,13 +1295,13 @@ def Tensor_PadOp : Tensor_Op<"pad", [
let builders = [
// Build a PadOp with mixed static and dynamic entries.
- OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$staticLow,
- "ArrayRef<int64_t>":$staticHigh, "ValueRange":$low, "ValueRange":$high,
- CArg<"bool", "false">:$nofold,
+ OpBuilder<(ins "Type":$resultType, "Value":$source,
+ "ArrayRef<int64_t>":$staticLow, "ArrayRef<int64_t>":$staticHigh,
+ "ValueRange":$low, "ValueRange":$high, CArg<"bool", "false">:$nofold,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build a PadOp with all dynamic entries.
- OpBuilder<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high,
- CArg<"bool", "false">:$nofold,
+ OpBuilder<(ins "Type":$resultType, "Value":$source, "ValueRange":$low,
+ "ValueRange":$high, CArg<"bool", "false">:$nofold,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build a PadOp with mixed static and dynamic entries and custom
// result type. If the type passed is nullptr, it is inferred.
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 1297e87714f79..c4f9fa8a6fe05 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -123,6 +123,11 @@ Operation *cloneWithoutRegions(OpBuilder &b, Operation *op,
TypeRange newResultTypes,
ValueRange newOperands);
+/// Returns the attributes attached to `op`, excluding every attribute
+/// whose name appears in `elidedAttrs`.
+SmallVector<NamedAttribute>
+getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs);
+
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f2da1088eb04d..9d26e51e04fd5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2518,26 +2518,27 @@ RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
return RankedTensorType::get(inferredShape, sourceType.getElementType());
}
-void PadOp::build(OpBuilder &b, OperationState &result, Value source,
- ArrayRef<int64_t> staticLow, ArrayRef<int64_t> staticHigh,
- ValueRange low, ValueRange high, bool nofold,
- ArrayRef<NamedAttribute> attrs) {
+void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
+ Value source, ArrayRef<int64_t> staticLow,
+ ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
+ bool nofold, ArrayRef<NamedAttribute> attrs) {
auto sourceType = source.getType().cast<RankedTensorType>();
- auto resultType = inferResultType(sourceType, staticLow, staticHigh);
+ if (!resultType)
+ resultType = inferResultType(sourceType, staticLow, staticHigh);
build(b, result, resultType, source, low, high,
b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
nofold ? b.getUnitAttr() : UnitAttr());
result.addAttributes(attrs);
}
-void PadOp::build(OpBuilder &b, OperationState &result, Value source,
- ValueRange low, ValueRange high, bool nofold,
+void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
+ Value source, ValueRange low, ValueRange high, bool nofold,
ArrayRef<NamedAttribute> attrs) {
auto sourceType = source.getType().cast<RankedTensorType>();
unsigned rank = sourceType.getRank();
SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
- build(b, result, source, staticVector, staticVector, low, high, nofold,
- attrs);
+ build(b, result, resultType, source, staticVector, staticVector, low, high,
+ nofold, attrs);
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
@@ -2635,9 +2636,9 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
} else {
auto newOp = rewriter.create<PadOp>(
padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
- padTensorOp.getLow(), padTensorOp.getHigh(),
padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
- padTensorOp.getNofold());
+ padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
+ getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
IRMapping mapper;
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
@@ -2667,9 +2668,10 @@ struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
auto replacementOp = rewriter.create<PadOp>(
padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
- padTensorOp.getSource(), padTensorOp.getLow(), padTensorOp.getHigh(),
- padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
- padTensorOp.getNofold());
+ padTensorOp.getSource(), padTensorOp.getStaticLow(),
+ padTensorOp.getStaticHigh(), padTensorOp.getLow(),
+ padTensorOp.getHigh(), padTensorOp.getNofold(),
+ getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
replacementOp.getRegion().takeBody(padTensorOp.getRegion());
rewriter.replaceOp(padTensorOp, replacementOp.getResult());
@@ -2827,7 +2829,8 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
innerSliceOp.getMixedStrides());
auto newPadOp = rewriter.create<PadOp>(
padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
- padOp.getMixedLowPad(), newHighPad, padOp.getNofold());
+ padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
+ getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
newPadOp.getRegion().begin());
rewriter.replaceOp(padOp, newPadOp.getResult());
@@ -2916,8 +2919,9 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
auto newResultType = RankedTensorType::get(
newOutDims, padTensorOp.getType().getElementType());
auto newOp = rewriter.create<PadOp>(
- padTensorOp->getLoc(), newResultType, input, padTensorOp.getLow(),
- padTensorOp.getHigh(), staticLow, staticHigh, padTensorOp.getNofold());
+ padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
+ padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
+ getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
IRMapping mapper;
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
index b22f42c09da59..49b49ef639708 100644
--- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
+#include "llvm/ADT/StringSet.h"
#include "mlir/Dialect/Utils/DialectUtilsEnums.cpp.inc"
@@ -114,3 +115,16 @@ Operation *mlir::cloneWithoutRegions(OpBuilder &b, Operation *op,
state.addRegion();
return b.create(state);
}
+
+SmallVector<NamedAttribute>
+mlir::getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs) {
+ llvm::StringSet elidedAttrsSet;
+ elidedAttrsSet.insert(elidedAttrs.begin(), elidedAttrs.end());
+ SmallVector<NamedAttribute> attrs;
+ for (auto attr : op->getAttrs()) {
+ if (elidedAttrsSet.count(attr.getName()))
+ continue;
+ attrs.push_back(attr);
+ }
+ return attrs;
+}
More information about the Mlir-commits mailing list