[Mlir-commits] [mlir] c54bc8b - [mlir][linalg] Use getIteratorTypeArray instead of raw iterator_type attribute.
Oleg Shyshkov
llvmlistbot at llvm.org
Fri Sep 30 09:04:02 PDT 2022
Author: Oleg Shyshkov
Date: 2022-09-30T16:03:33Z
New Revision: c54bc8bd07dc38135b6800922e261d573dcac956
URL: https://github.com/llvm/llvm-project/commit/c54bc8bd07dc38135b6800922e261d573dcac956
DIFF: https://github.com/llvm/llvm-project/commit/c54bc8bd07dc38135b6800922e261d573dcac956.diff
LOG: [mlir][linalg] Use getIteratorTypeArray instead of raw iterator_type attribute.
Summary:
Also modify helper methods to take StringRefs instead of Attributes. It makes
the code cleaner and will help with future migration from StringRef to
utils::IteratorType ([RFC](https://discourse.llvm.org/t/rfc-enumattr-for-iterator-types-in-linalg/64535)).
Differential Revision: https://reviews.llvm.org/D134888
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index b72d0944ded67..c989da01238a4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -194,7 +194,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
return getNumIterators(getParallelIteratorTypeName(),
- $_op.iterator_types());
+ $_op.getIteratorTypesArray());
}]
>,
InterfaceMethod<
@@ -206,7 +206,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return findPositionsOfType($_op.iterator_types(),
+ return findPositionsOfType($_op.getIteratorTypesArray(),
getParallelIteratorTypeName(), res);
}]
>,
@@ -220,7 +220,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
return getNumIterators(getReductionIteratorTypeName(),
- $_op.iterator_types());
+ $_op.getIteratorTypesArray());
}]
>,
InterfaceMethod<
@@ -232,7 +232,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return findPositionsOfType($_op.iterator_types(),
+ return findPositionsOfType($_op.getIteratorTypesArray(),
getReductionIteratorTypeName(), res);
}]
>,
@@ -246,7 +246,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
return getNumIterators(getWindowIteratorTypeName(),
- $_op.iterator_types());
+ $_op.getIteratorTypesArray());
}]
>,
InterfaceMethod<
@@ -258,7 +258,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return findPositionsOfType($_op.iterator_types(),
+ return findPositionsOfType($_op.getIteratorTypesArray(),
getWindowIteratorTypeName(), res);
}]
>,
@@ -271,7 +271,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return getNumIterators($_op.iterator_types());
+ return getNumIterators($_op.getIteratorTypesArray());
}]
>,
InterfaceMethod<
@@ -284,7 +284,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- auto iters = $_op.iterator_types();
+ auto iters = $_op.getIteratorTypesArray();
return iters.size() == 1 &&
getNumIterators(getReductionIteratorTypeName(), iters) == 1;
}]>,
@@ -759,7 +759,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
ArrayAttr getIteratorTypes() { return iterator_types(); }
SmallVector<StringRef> getIteratorTypeNames() {
- return llvm::to_vector(getIteratorTypes().getAsValueRange<StringAttr>());
+ return getIteratorTypesArray();
}
//========================================================================//
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index eae16cdc64441..3ec6fc7522d23 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -45,11 +45,11 @@ bool isElementwise(LinalgOp op);
/// `[0, permutation.size())`.
bool isPermutation(ArrayRef<int64_t> permutation);
-/// Check if `attr` has "parallel" iterator type semantics.
-bool isParallelIterator(Attribute attr);
+/// Check if iterator type has "parallel" semantics.
+bool isParallelIterator(StringRef iteratorType);
-/// Check if `attr` has "reduction" iterator type semantics.
-bool isReductionIterator(Attribute attr);
+/// Check if iterator type has "reduction" semantics.
+bool isReductionIterator(StringRef iteratorType);
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
@@ -488,7 +488,7 @@ struct RegionMatcher {
template <typename LoopTy>
struct GenerateLoopNest {
static void doit(OpBuilder &b, Location loc, ArrayRef<Range> loopRanges,
- LinalgOp linalgOp, ArrayRef<Attribute> iteratorTypes,
+ LinalgOp linalgOp, ArrayRef<StringRef> iteratorTypes,
function_ref<scf::ValueVector(OpBuilder &, Location,
ValueRange, ValueRange)>
bodyBuilderFn,
diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 8f7ac8cc2cee6..5086682ac60ee 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -97,16 +97,15 @@ inline ArrayRef<StringRef> getAllIteratorTypeNames() {
}
/// Returns the iterator of a certain type.
-inline unsigned getNumIterators(StringRef name, ArrayAttr iteratorTypes) {
+inline unsigned getNumIterators(StringRef name,
+ ArrayRef<StringRef> iteratorTypes) {
auto names = getAllIteratorTypeNames();
(void)names;
assert(llvm::is_contained(names, name));
- return llvm::count_if(iteratorTypes, [name](Attribute a) {
- return a.cast<StringAttr>().getValue() == name;
- });
+ return llvm::count(iteratorTypes, name);
}
-inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
+inline unsigned getNumIterators(ArrayRef<StringRef> iteratorTypes) {
unsigned res = 0;
for (auto n : getAllIteratorTypeNames())
res += getNumIterators(n, iteratorTypes);
@@ -114,11 +113,10 @@ inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
}
/// Return positions in `iteratorTypes` that match `iteratorTypeName`.
-inline void findPositionsOfType(ArrayAttr iteratorTypes,
+inline void findPositionsOfType(ArrayRef<StringRef> iteratorTypes,
StringRef iteratorTypeName,
SmallVectorImpl<unsigned> &res) {
- for (const auto &en :
- llvm::enumerate(iteratorTypes.getAsValueRange<StringAttr>())) {
+ for (const auto &en : llvm::enumerate(iteratorTypes)) {
if (en.value() == iteratorTypeName)
res.push_back(en.index());
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 8e7c797e34456..9a62a40c42be0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -297,7 +297,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
!indexingMaps.back().isProjectedPermutation())
return MatchConvolutionResult::NotProjectedPermutations;
- auto iteratorTypesRange = linalgOp.getIteratorTypesArray();
+ auto iteratorTypes = linalgOp.getIteratorTypesArray();
llvm::SmallDenseSet<unsigned> outputDims =
getPreservedDims(indexingMaps.back());
@@ -321,8 +321,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
if (inputExprWalker.unConvolvedDims.count(outputDim) &&
!filterDims.count(outputDim)) {
// Batch dimension.
- if (*std::next(iteratorTypesRange.begin(), outputDim) !=
- getParallelIteratorTypeName())
+ if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
continue;
@@ -330,8 +329,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
if (inputExprWalker.convolvedDims.count(outputDim) &&
!filterDims.count(outputDim)) {
// Output image Loop dimension.
- if (*std::next(iteratorTypesRange.begin(), outputDim) !=
- getParallelIteratorTypeName())
+ if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
continue;
@@ -340,8 +338,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
!inputExprWalker.unConvolvedDims.count(outputDim) &&
filterDims.count(outputDim)) {
// Output channel dimension.
- if (*std::next(iteratorTypesRange.begin(), outputDim) !=
- getParallelIteratorTypeName())
+ if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
continue;
@@ -349,8 +346,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
if (inputExprWalker.unConvolvedDims.count(outputDim) &&
filterDims.count(outputDim)) {
// Depth multiplier.
- if (*std::next(iteratorTypesRange.begin(), outputDim) !=
- getParallelIteratorTypeName())
+ if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
return MatchConvolutionResult::OutputDimsNotParallel;
allLoopDims.insert(outputDim);
continue;
@@ -368,8 +364,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
if (inputExprWalker.convolvedDims.count(filterDim) &&
!outputDims.count(filterDim)) {
// Filter loop dimension.
- if (*std::next(iteratorTypesRange.begin(), filterDim) !=
- getReductionIteratorTypeName())
+ if (iteratorTypes[filterDim] != getReductionIteratorTypeName())
return MatchConvolutionResult::NonOutputDimNotReduction;
if (allLoopDims.count(filterDim))
return MatchConvolutionResult::NonConvolutionLoop;
@@ -379,8 +374,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
if (inputExprWalker.unConvolvedDims.count(filterDim) &&
!outputDims.count(filterDim)) {
// Input channel dimension.
- if (*std::next(iteratorTypesRange.begin(), filterDim) !=
- getReductionIteratorTypeName())
+ if (iteratorTypes[filterDim] != getReductionIteratorTypeName())
return MatchConvolutionResult::NonOutputDimNotReduction;
if (allLoopDims.count(filterDim))
return MatchConvolutionResult::NonConvolutionLoop;
@@ -634,8 +628,7 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
LinalgOp linalgOp = cast<LinalgOp>(op);
// Check all iterator types are known.
- auto iteratorTypesRange =
- linalgOp.iterator_types().getAsValueRange<StringAttr>();
+ auto iteratorTypesRange = linalgOp.getIteratorTypesArray();
for (StringRef iteratorType : iteratorTypesRange) {
if (!llvm::is_contained(getAllIteratorTypeNames(), iteratorType) ||
!utils::symbolizeIteratorType(iteratorType).has_value())
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index eb2f2f1c3cee8..b6d1d21a66591 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1055,17 +1055,12 @@ getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
}
// Compute all the loops with the reduction iterator types.
- SmallVector<int64_t> reductionDims;
- for (const auto &iteratorType :
- llvm::enumerate(genericOp.getIteratorTypes())) {
- if (isReductionIterator(iteratorType.value())) {
- reductionDims.push_back(iteratorType.index());
- }
- }
+ SmallVector<unsigned> reductionDims;
+ genericOp.getReductionDims(reductionDims);
llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
- auto iteratorTypes = genericOp.getIteratorTypes().getValue();
+ auto iteratorTypes = genericOp.getIteratorTypesArray();
SmallVector<ReassociationIndices> iterationSpaceReassociation;
for (ReassociationIndicesRef foldedRangeDims : reassociation) {
assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
@@ -1085,7 +1080,7 @@ getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
continue;
// Check that all folded iterator types are all parallel or all reductions.
- Attribute startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
+ StringRef startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
if (!isParallelIterator(startIteratorType) &&
!isReductionIterator(startIteratorType))
continue;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index d4e3b52f30e3f..d29b767df9d71 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -433,12 +433,9 @@ FailureOr<TileLoopNest> mlir::linalg::tileConsumerAndFuseProducers(
// Search the number of outer parallel loops to separate them from possible
// inner reduction dimensions.
- SmallVector<StringAttr> iterTypes =
- llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>());
+ SmallVector<StringRef> iterTypes = consumerOp.getIteratorTypesArray();
applyPermutationToVector(iterTypes, tileInterchange);
- auto *it = find_if(iterTypes, [&](StringAttr iterType) {
- return !isParallelIterator(iterType);
- });
+ auto *it = find_if_not(iterTypes, isParallelIterator);
int64_t split = std::distance(iterTypes.begin(), it);
// Helper to fuse the producers greedily using a queue of fusion candidates.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index a14994f1bd077..3052a4db29464 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -217,7 +217,7 @@ static FailureOr<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
"expected linalg op with buffer semantics");
auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
- auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());
+ auto iteratorTypes = linalgOp.getIteratorTypesArray();
SmallVector<Value> allIvs;
GenerateLoopNest<LoopTy>::doit(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 56a7437d01f69..57153499a2da2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -215,10 +215,10 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
op.getContext()));
SmallVector<StringRef> newIteratorTypes;
- for (auto &it : llvm::enumerate(op.iterator_types())) {
+ for (auto &it : llvm::enumerate(op.getIteratorTypesArray())) {
if (insertSplitDimension == it.index() && !control.innerParallel)
newIteratorTypes.push_back(getParallelIteratorTypeName());
- newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
+ newIteratorTypes.push_back(it.value());
if (insertSplitDimension == it.index() && control.innerParallel)
newIteratorTypes.push_back(getParallelIteratorTypeName());
}
@@ -413,8 +413,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
// Step 4. Create the new op matching the original op with an extra parallel
// dimension.
- SmallVector<StringRef> iteratorTypes =
- llvm::to_vector<4>(op.getIteratorTypes().getAsValueRange<StringAttr>());
+ auto iteratorTypes = op.getIteratorTypesArray();
iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
getParallelIteratorTypeName());
GenericOp genericOp =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index b0aa41eebc400..b389ceb2892d7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -427,9 +427,8 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges(
b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
- SmallVector<Attribute, 4> iteratorTypes;
- for (const auto &attr :
- enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
+ SmallVector<StringRef, 4> iteratorTypes;
+ for (const auto &attr : enumerate(op.getIteratorTypesArray())) {
if (loopIndexToRangeIndex.count(attr.index()))
iteratorTypes.push_back(attr.value());
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 46dc2324eff2d..0a117df6f9ab1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -190,14 +190,8 @@ static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
}
static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
- unsigned idx = 0;
- SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
- for (auto attr : linalgOp.iterator_types()) {
- if (isReductionIterator(attr))
- reductionMask[idx] = true;
- ++idx;
- }
- return reductionMask;
+ return llvm::to_vector(
+ llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}
/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
@@ -540,7 +534,7 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
// TODO: probably need some extra checks for reduction followed by consumer
// ops that may not commute (e.g. linear reduction + non-linear instructions).
static LogicalResult reductionPreconditions(LinalgOp op) {
- if (llvm::none_of(op.iterator_types(), isReductionIterator)) {
+ if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
LDBG("reduction precondition failed: no reduction iterator");
return failure();
}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index d39fa11f364a2..999034b4e36b0 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -199,14 +199,12 @@ bool isPermutation(ArrayRef<int64_t> permutation) {
return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
}
-bool isParallelIterator(Attribute attr) {
- auto strAttr = attr.dyn_cast_or_null<StringAttr>();
- return strAttr && strAttr.getValue() == getParallelIteratorTypeName();
+bool isParallelIterator(StringRef iteratorType) {
+ return iteratorType == getParallelIteratorTypeName();
}
-bool isReductionIterator(Attribute attr) {
- auto strAttr = attr.dyn_cast_or_null<StringAttr>();
- return strAttr && strAttr.getValue() == getReductionIteratorTypeName();
+bool isReductionIterator(StringRef iteratorType) {
+ return iteratorType == getReductionIteratorTypeName();
}
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
@@ -484,7 +482,7 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
template <>
void GenerateLoopNest<scf::ForOp>::doit(
OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
- ArrayRef<Attribute> iteratorTypes,
+ ArrayRef<StringRef> iteratorTypes,
function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
ValueRange)>
bodyBuilderFn,
@@ -527,7 +525,7 @@ void GenerateLoopNest<scf::ForOp>::doit(
template <>
void GenerateLoopNest<AffineForOp>::doit(
OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
- ArrayRef<Attribute> iteratorTypes,
+ ArrayRef<StringRef> iteratorTypes,
function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
ValueRange)>
bodyBuilderFn,
@@ -577,7 +575,7 @@ void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
// exceeds 10.
static void generateParallelLoopNest(
OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
- ValueRange steps, ArrayRef<Attribute> iteratorTypes,
+ ValueRange steps, ArrayRef<StringRef> iteratorTypes,
ArrayRef<linalg::ProcInfo> procInfo,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
SmallVectorImpl<Value> &ivStorage) {
@@ -692,7 +690,7 @@ static void generateParallelLoopNest(
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
- ArrayRef<Attribute> iteratorTypes,
+ ArrayRef<StringRef> iteratorTypes,
function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
ValueRange)>
bodyBuilderFn,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index ddcf839b1822e..88bd885393b7b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -214,7 +214,7 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
/// as we use adj matrix for the graph.
/// The sorted result will put the first Reduction iterator to the
/// latest possible index.
-static bool topSortOptimal(unsigned n, ArrayRef<Attribute> iteratorTypes,
+static bool topSortOptimal(unsigned n, ArrayRef<StringRef> iteratorTypes,
std::vector<unsigned> &topSort,
std::vector<unsigned> &inDegree,
std::vector<std::vector<bool>> &adjM) {
@@ -289,7 +289,7 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
unsigned n = op.getNumLoops();
std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
std::vector<unsigned> inDegree(n, 0); // in-degree of each node.
- auto iteratorTypes = op.iterator_types().getValue();
+ auto iteratorTypes = op.getIteratorTypesArray();
// Iterate over the indexing maps of every tensor in the tensor expression.
for (OpOperand *t : op.getInputAndOutputOperands()) {
// Skip tensor during cycle resolution.
@@ -361,7 +361,7 @@ static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op,
// An all-dense annotated "sparse" output tensor becomes a linearized random
// access 1-dim memref. Also admissible since insertions cannot occur.
bool allDense = true;
- auto iteratorTypes = op.iterator_types().getValue();
+ auto iteratorTypes = op.getIteratorTypesArray();
unsigned numLoops = iteratorTypes.size();
for (unsigned i = 0; i < numLoops; i++)
if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
@@ -1299,7 +1299,7 @@ static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder,
unsigned fb = indices.find_first();
unsigned tensor = merger.tensor(fb);
assert(idx == merger.index(fb));
- auto iteratorTypes = op.iterator_types().getValue();
+ auto iteratorTypes = op.getIteratorTypesArray();
bool isReduction = linalg::isReductionIterator(iteratorTypes[idx]);
bool isSparse = merger.isDimLevelType(fb, DimLvlType::kCompressed) ||
merger.isDimLevelType(fb, DimLvlType::kSingleton);
More information about the Mlir-commits
mailing list