[Mlir-commits] [mlir] c38d9cf - [mlir] Remove iterator_types() method from LinalgStructuredInterface.
Oleg Shyshkov
llvmlistbot at llvm.org
Thu Oct 13 00:53:11 PDT 2022
Author: Oleg Shyshkov
Date: 2022-10-13T07:52:43Z
New Revision: c38d9cf20e7468a2618dc23fcdc66e79c925aff5
URL: https://github.com/llvm/llvm-project/commit/c38d9cf20e7468a2618dc23fcdc66e79c925aff5
DIFF: https://github.com/llvm/llvm-project/commit/c38d9cf20e7468a2618dc23fcdc66e79c925aff5.diff
LOG: [mlir] Remove iterator_types() method from LinalgStructuredInterface.
`getIteratorTypesArray` should be used instead. It's a better substitute for all the current usages of the interface.
The current `ArrayAttr iterator_types()` has a few problems:
* It creates an assumption that the operation has iterator types as an attribute, but that's not always the case. Sometimes iterator types can be inferred from other attributes, or they're just static.
* ArrayAttr is an obscure container and requires extracting values in the client code.
* Makes it hard to migrate iterator types from strings to enums ([RFC](https://discourse.llvm.org/t/rfc-enumattr-for-iterator-types-in-linalg/64535/9)).
Concrete ops, like `linalg.generic`, will still have iterator types as an attribute if needed.
As a side effect, this change helps a bit with migration to prefixed accessors.
Differential Revision: https://reviews.llvm.org/D135765
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 268587e312077..69871faa5a2cc 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -497,28 +497,21 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
return $_op.getBody();
}]
>,
- InterfaceMethod<
- /*desc=*/[{
- Return the iterator types attribute within the current operation.
- }],
- /*retTy=*/"ArrayAttr",
- /*methodName=*/"iterator_types",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- return $_op.getIteratorTypes();
- }]
- >,
InterfaceMethod<
/*desc=*/[{
Return iterator types in the current operation.
+
+ Default implementation assumes that the operation has an attribute
+ `iterator_types`, but it's not always the case. Sometimes iterator types
+ can be infered from other parameters and in such cases default
+ getIteratorTypesArray should be overriden.
}],
/*retTy=*/"SmallVector<StringRef>",
/*methodName=*/"getIteratorTypesArray",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- auto range = $_op.iterator_types().template getAsValueRange<StringAttr>();
+ auto range = $_op.getIteratorTypes().template getAsValueRange<StringAttr>();
return {range.begin(), range.end()};
}]
>,
@@ -773,9 +766,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
LogicalResult reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes);
- // TODO: Remove once prefixing is flipped.
- ArrayAttr getIteratorTypes() { return iterator_types(); }
-
SmallVector<StringRef> getIteratorTypeNames() {
return getIteratorTypesArray();
}
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 6bcf509844cb8..2619ad1186408 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -264,7 +264,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
let extraClassDeclaration = structuredOpsBaseDecls # [{
// Implement functions necessary for LinalgStructuredInterface.
- ArrayAttr getIteratorTypes();
+ SmallVector<StringRef> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
@@ -334,7 +334,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare functions necessary for LinalgStructuredInterface.
- ArrayAttr getIteratorTypes();
+ SmallVector<StringRef> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 61b938601e103..c2705a383550a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1393,11 +1393,9 @@ LogicalResult MapOp::verify() {
return success();
}
-ArrayAttr MapOp::getIteratorTypes() {
+SmallVector<StringRef> MapOp::getIteratorTypesArray() {
int64_t rank = getInit().getType().getRank();
- return Builder(getContext())
- .getStrArrayAttr(
- SmallVector<StringRef>(rank, getParallelIteratorTypeName()));
+ return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
}
ArrayAttr MapOp::getIndexingMaps() {
@@ -1435,13 +1433,13 @@ void ReduceOp::getAsmResultNames(
setNameFn(getResults().front(), "reduced");
}
-ArrayAttr ReduceOp::getIteratorTypes() {
+SmallVector<StringRef> ReduceOp::getIteratorTypesArray() {
int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
SmallVector<StringRef> iteratorTypes(inputRank,
getParallelIteratorTypeName());
for (int64_t reductionDim : getDimensions())
iteratorTypes[reductionDim] = getReductionIteratorTypeName();
- return Builder(getContext()).getStrArrayAttr(iteratorTypes);
+ return iteratorTypes;
}
ArrayAttr ReduceOp::getIndexingMaps() {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index a9113f0d05713..66d55dcf5c713 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -92,11 +92,9 @@ struct LinalgOpTilingInterface
/// Return the loop iterator type.
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
- return llvm::to_vector(
- llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) {
- return utils::symbolizeIteratorType(
- strAttr.cast<StringAttr>().getValue())
- .value();
+ return llvm::to_vector(llvm::map_range(
+ concreteOp.getIteratorTypesArray(), [](StringRef iteratorType) {
+ return utils::symbolizeIteratorType(iteratorType).value();
}));
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index a28aa3b72ff43..2458dabef56c9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -250,7 +250,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
// Fuse producer and consumer into a new generic op.
auto fusedOp = rewriter.create<GenericOp>(
loc, op.getResult(0).getType(), inputOps, outputOps,
- rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.iterator_types(),
+ rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
/*doc=*/nullptr, /*library_call=*/nullptr);
Block &prodBlock = prod.getRegion().front();
Block &consBlock = op.getRegion().front();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 21ee7b5119763..1418ed4da4a4a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1857,7 +1857,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
if (op.getNumOutputs() != 1)
return failure();
unsigned numTensors = op.getNumInputsAndOutputs();
- unsigned numLoops = op.iterator_types().getValue().size();
+ unsigned numLoops = op.getNumLoops();
Merger merger(numTensors, numLoops);
if (!findSparseAnnotations(merger, op))
return failure();
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 1fca43ddff411..7f329126b2892 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2816,7 +2816,7 @@ def TestLinalgConvOp :
return ®ionBuilder;
}
- mlir::ArrayAttr iterator_types() {
+ mlir::ArrayAttr getIteratorTypes() {
return getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
}
@@ -2875,7 +2875,7 @@ def TestLinalgFillOp :
return ®ionBuilder;
}
- mlir::ArrayAttr iterator_types() {
+ mlir::ArrayAttr getIteratorTypes() {
return getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
}
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index 205cde22c9970..51557f519772a 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -235,7 +235,7 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: value
-# IMPL: Test3Op::iterator_types() {
+# IMPL: Test3Op::getIteratorTypesArray() {
# IMPL-NEXT: int64_t rank = getRank(getOutputOperand(0));
# IMPL: Test3Op::getIndexingMaps() {
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 1bea98fe71bfb..8156bb97a32f3 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -553,7 +553,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
let extraClassDeclaration = structuredOpsBaseDecls # [{{
// Auto-generated.
- ArrayAttr iterator_types();
+ SmallVector<StringRef> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);
@@ -587,24 +587,24 @@ static const char structuredOpBuilderFormat[] = R"FMT(
}]>
)FMT";
-// The iterator_types() method for structured ops. Parameters:
+// The getIteratorTypesArray() method for structured ops. Parameters:
// {0}: Class name
// {1}: Comma interleaved iterator type names.
static const char structuredOpIteratorTypesFormat[] =
R"FMT(
-ArrayAttr {0}::iterator_types() {{
- return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef>{{ {1} });
+SmallVector<StringRef> {0}::getIteratorTypesArray() {{
+ return SmallVector<StringRef>{{ {1} };
}
)FMT";
-// The iterator_types() method for rank polymorphic structured ops. Parameters:
+// The getIteratorTypesArray() method for rank polymorphic structured ops.
+// Parameters:
// {0}: Class name
static const char rankPolyStructuredOpIteratorTypesFormat[] =
R"FMT(
-ArrayAttr {0}::iterator_types() {{
+SmallVector<StringRef> {0}::getIteratorTypesArray() {{
int64_t rank = getRank(getOutputOperand(0));
- return Builder(getContext()).getStrArrayAttr(
- SmallVector<StringRef>(rank, getParallelIteratorTypeName()));
+ return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
}
)FMT";
More information about the Mlir-commits
mailing list