[Mlir-commits] [mlir] 1e7a6c0 - [mlir][linalg] Add getIteratorTypesArray to LinalgInterface.
Oleg Shyshkov
llvmlistbot at llvm.org
Tue Sep 27 07:31:56 PDT 2022
Author: Oleg Shyshkov
Date: 2022-09-27T14:30:50Z
New Revision: 1e7a6c0874f57014132b857b2d201d6aaa75feee
URL: https://github.com/llvm/llvm-project/commit/1e7a6c0874f57014132b857b2d201d6aaa75feee
DIFF: https://github.com/llvm/llvm-project/commit/1e7a6c0874f57014132b857b2d201d6aaa75feee.diff
LOG: [mlir][linalg] Add getIteratorTypesArray to LinalgInterface.
Summary:
Most of the code that gets `iterator_types` from LinalgInterface is forced to
extract values from an `Attribute`. As a result, the usage pattern looks like
this:
```
SmallVector<StringRef> iterators = llvm::to_vector<4>(linalgOp.iterator_types().getAsValueRange<StringAttr>());
```
It also forces all operations that implement LinalgOp interface to have
`iterator_types` attribute even when the information can be easily inferred from
other parameters.
In a perfect future, `getIteratorTypesArray` should be the only method to get
iterator types from the interface. The default implementation can rely on
`iterator_types` attribute though.
The name `getIteratorTypesArray` was picked to be consistent with existing
`getIndexingMapsArray`.
This patch adds a few sample usages. More cleanups will follow.
Differential Revision: https://reviews.llvm.org/D134729
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 64c2bd1b28510..38031704dbf5f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -491,6 +491,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
return $_op.iterator_types();
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return iterator types in the current operation.
+ }],
+ /*retTy=*/"SmallVector<StringRef>",
+ /*methodName=*/"getIteratorTypesArray",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto range = $_op.iterator_types().template getAsValueRange<StringAttr>();
+ return {range.begin(), range.end()};
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Return true if the indexing map is depending on the current op instance.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 6e4c2fc9d7393..032578fc5e25b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -297,8 +297,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
!indexingMaps.back().isProjectedPermutation())
return MatchConvolutionResult::NotProjectedPermutations;
- auto iteratorTypesRange =
- linalgOp.iterator_types().getAsValueRange<StringAttr>();
+ auto iteratorTypesRange = linalgOp.getIteratorTypesArray();
llvm::SmallDenseSet<unsigned> outputDims =
getPreservedDims(indexingMaps.back());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index d58dc4a6540f5..15ba43cb98944 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -438,8 +438,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
GenericOp replacementOp = rewriter.create<GenericOp>(
loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
- llvm::to_vector<4>(genericOp.getIteratorTypes()
- .template getAsValueRange<StringAttr>()));
+ genericOp.getIteratorTypesArray());
rewriter.inlineRegionBefore(genericOp.getRegion(),
replacementOp.getRegion(),
replacementOp.getRegion().begin());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 97dd0c448403f..94526e140298c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -467,9 +467,8 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
.isProjectedPermutation();
}) &&
genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
- llvm::all_of(genericOp.getIteratorTypes(), [](Attribute attr) {
- return attr.cast<StringAttr>().getValue() ==
- getParallelIteratorTypeName();
+ llvm::all_of(genericOp.getIteratorTypesArray(), [](StringRef it) {
+ return it == getParallelIteratorTypeName();
});
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 4f384c33e9945..97eee8b3e5091 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -53,8 +53,7 @@ FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
SmallVector<Value> inputOperands = linalgOp.getInputOperands();
SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
- SmallVector<StringRef> iterators = llvm::to_vector<4>(
- linalgOp.iterator_types().getAsValueRange<StringAttr>());
+ SmallVector<StringRef> iterators = linalgOp.getIteratorTypesArray();
SmallVector<RankedTensorType> resultTypes = linalgOp.getOutputTensorTypes();
SmallVector<Type> types(resultTypes.begin(), resultTypes.end());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
index e8f5372e4fdc0..a53e7c5ff086e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -61,9 +61,7 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
auto newOp = rewriter.create<GenericOp>(
loc, genericOp->getResultTypes(), newOperands, outputOperands,
- newIndexingMaps,
- llvm::to_vector<4>(genericOp.getIteratorTypes()
- .template getAsValueRange<StringAttr>()));
+ newIndexingMaps, genericOp.getIteratorTypesArray());
rewriter.cloneRegionBefore(genericOp.getRegion(), newOp.getRegion(),
newOp.getRegion().begin());
More information about the Mlir-commits
mailing list