[Mlir-commits] [mlir] c54bc8b - [mlir][linalg] Use getIteratorTypeArray instead of raw iterator_type attribute.

Oleg Shyshkov llvmlistbot at llvm.org
Fri Sep 30 09:04:02 PDT 2022


Author: Oleg Shyshkov
Date: 2022-09-30T16:03:33Z
New Revision: c54bc8bd07dc38135b6800922e261d573dcac956

URL: https://github.com/llvm/llvm-project/commit/c54bc8bd07dc38135b6800922e261d573dcac956
DIFF: https://github.com/llvm/llvm-project/commit/c54bc8bd07dc38135b6800922e261d573dcac956.diff

LOG: [mlir][linalg] Use getIteratorTypeArray instead of raw iterator_type attribute.

Summary:
Also modify helper methods to take StringRefs instead of Attributes. It makes
the code cleaner and will help with future migration from StringRef to
utils::IteratorType ([RFC](https://discourse.llvm.org/t/rfc-enumattr-for-iterator-types-in-linalg/64535)).

Differential Revision: https://reviews.llvm.org/D134888

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index b72d0944ded67..c989da01238a4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -194,7 +194,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return getNumIterators(getParallelIteratorTypeName(),
-                               $_op.iterator_types());
+                                $_op.getIteratorTypesArray());
       }]
     >,
     InterfaceMethod<
@@ -206,7 +206,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return findPositionsOfType($_op.iterator_types(),
+        return findPositionsOfType($_op.getIteratorTypesArray(),
                                    getParallelIteratorTypeName(), res);
       }]
     >,
@@ -220,7 +220,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return getNumIterators(getReductionIteratorTypeName(),
-                               $_op.iterator_types());
+                               $_op.getIteratorTypesArray());
       }]
     >,
     InterfaceMethod<
@@ -232,7 +232,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return findPositionsOfType($_op.iterator_types(),
+        return findPositionsOfType($_op.getIteratorTypesArray(),
                                    getReductionIteratorTypeName(), res);
       }]
     >,
@@ -246,7 +246,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return getNumIterators(getWindowIteratorTypeName(),
-                               $_op.iterator_types());
+                               $_op.getIteratorTypesArray());
       }]
     >,
     InterfaceMethod<
@@ -258,7 +258,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return findPositionsOfType($_op.iterator_types(),
+        return findPositionsOfType($_op.getIteratorTypesArray(),
                                    getWindowIteratorTypeName(), res);
       }]
     >,
@@ -271,7 +271,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return getNumIterators($_op.iterator_types());
+        return getNumIterators($_op.getIteratorTypesArray());
       }]
     >,
     InterfaceMethod<
@@ -284,7 +284,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        auto iters = $_op.iterator_types();
+        auto iters = $_op.getIteratorTypesArray();
         return iters.size() == 1 &&
                getNumIterators(getReductionIteratorTypeName(), iters) == 1;
       }]>,
@@ -759,7 +759,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     ArrayAttr getIteratorTypes() { return iterator_types(); }
 
     SmallVector<StringRef> getIteratorTypeNames() {
-      return llvm::to_vector(getIteratorTypes().getAsValueRange<StringAttr>());
+      return getIteratorTypesArray();
     }
 
     //========================================================================//

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index eae16cdc64441..3ec6fc7522d23 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -45,11 +45,11 @@ bool isElementwise(LinalgOp op);
 /// `[0, permutation.size())`.
 bool isPermutation(ArrayRef<int64_t> permutation);
 
-/// Check if `attr` has "parallel" iterator type semantics.
-bool isParallelIterator(Attribute attr);
+/// Check if iterator type has "parallel" semantics.
+bool isParallelIterator(StringRef iteratorType);
 
-/// Check if `attr` has "reduction" iterator type semantics.
-bool isReductionIterator(Attribute attr);
+/// Check if iterator type  has "reduction" semantics.
+bool isReductionIterator(StringRef iteratorType);
 
 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
 /// the type of `source`.
@@ -488,7 +488,7 @@ struct RegionMatcher {
 template <typename LoopTy>
 struct GenerateLoopNest {
   static void doit(OpBuilder &b, Location loc, ArrayRef<Range> loopRanges,
-                   LinalgOp linalgOp, ArrayRef<Attribute> iteratorTypes,
+                   LinalgOp linalgOp, ArrayRef<StringRef> iteratorTypes,
                    function_ref<scf::ValueVector(OpBuilder &, Location,
                                                  ValueRange, ValueRange)>
                        bodyBuilderFn,

diff  --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 8f7ac8cc2cee6..5086682ac60ee 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -97,16 +97,15 @@ inline ArrayRef<StringRef> getAllIteratorTypeNames() {
 }
 
 /// Returns the iterator of a certain type.
-inline unsigned getNumIterators(StringRef name, ArrayAttr iteratorTypes) {
+inline unsigned getNumIterators(StringRef name,
+                                ArrayRef<StringRef> iteratorTypes) {
   auto names = getAllIteratorTypeNames();
   (void)names;
   assert(llvm::is_contained(names, name));
-  return llvm::count_if(iteratorTypes, [name](Attribute a) {
-    return a.cast<StringAttr>().getValue() == name;
-  });
+  return llvm::count(iteratorTypes, name);
 }
 
-inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
+inline unsigned getNumIterators(ArrayRef<StringRef> iteratorTypes) {
   unsigned res = 0;
   for (auto n : getAllIteratorTypeNames())
     res += getNumIterators(n, iteratorTypes);
@@ -114,11 +113,10 @@ inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
 }
 
 /// Return positions in `iteratorTypes` that match `iteratorTypeName`.
-inline void findPositionsOfType(ArrayAttr iteratorTypes,
+inline void findPositionsOfType(ArrayRef<StringRef> iteratorTypes,
                                 StringRef iteratorTypeName,
                                 SmallVectorImpl<unsigned> &res) {
-  for (const auto &en :
-       llvm::enumerate(iteratorTypes.getAsValueRange<StringAttr>())) {
+  for (const auto &en : llvm::enumerate(iteratorTypes)) {
     if (en.value() == iteratorTypeName)
       res.push_back(en.index());
   }

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 8e7c797e34456..9a62a40c42be0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -297,7 +297,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
       !indexingMaps.back().isProjectedPermutation())
     return MatchConvolutionResult::NotProjectedPermutations;
 
-  auto iteratorTypesRange = linalgOp.getIteratorTypesArray();
+  auto iteratorTypes = linalgOp.getIteratorTypesArray();
 
   llvm::SmallDenseSet<unsigned> outputDims =
       getPreservedDims(indexingMaps.back());
@@ -321,8 +321,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
     if (inputExprWalker.unConvolvedDims.count(outputDim) &&
         !filterDims.count(outputDim)) {
       // Batch dimension.
-      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
-          getParallelIteratorTypeName())
+      if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
       continue;
@@ -330,8 +329,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
     if (inputExprWalker.convolvedDims.count(outputDim) &&
         !filterDims.count(outputDim)) {
       // Output image Loop dimension.
-      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
-          getParallelIteratorTypeName())
+      if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
       continue;
@@ -340,8 +338,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
         !inputExprWalker.unConvolvedDims.count(outputDim) &&
         filterDims.count(outputDim)) {
       // Output channel dimension.
-      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
-          getParallelIteratorTypeName())
+      if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
       continue;
@@ -349,8 +346,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
     if (inputExprWalker.unConvolvedDims.count(outputDim) &&
         filterDims.count(outputDim)) {
       // Depth multiplier.
-      if (*std::next(iteratorTypesRange.begin(), outputDim) !=
-          getParallelIteratorTypeName())
+      if (iteratorTypes[outputDim] != getParallelIteratorTypeName())
         return MatchConvolutionResult::OutputDimsNotParallel;
       allLoopDims.insert(outputDim);
       continue;
@@ -368,8 +364,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
     if (inputExprWalker.convolvedDims.count(filterDim) &&
         !outputDims.count(filterDim)) {
       // Filter loop dimension.
-      if (*std::next(iteratorTypesRange.begin(), filterDim) !=
-          getReductionIteratorTypeName())
+      if (iteratorTypes[filterDim] != getReductionIteratorTypeName())
         return MatchConvolutionResult::NonOutputDimNotReduction;
       if (allLoopDims.count(filterDim))
         return MatchConvolutionResult::NonConvolutionLoop;
@@ -379,8 +374,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
     if (inputExprWalker.unConvolvedDims.count(filterDim) &&
         !outputDims.count(filterDim)) {
       // Input channel dimension.
-      if (*std::next(iteratorTypesRange.begin(), filterDim) !=
-          getReductionIteratorTypeName())
+      if (iteratorTypes[filterDim] != getReductionIteratorTypeName())
         return MatchConvolutionResult::NonOutputDimNotReduction;
       if (allLoopDims.count(filterDim))
         return MatchConvolutionResult::NonConvolutionLoop;
@@ -634,8 +628,7 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
 
   // Check all iterator types are known.
-  auto iteratorTypesRange =
-      linalgOp.iterator_types().getAsValueRange<StringAttr>();
+  auto iteratorTypesRange = linalgOp.getIteratorTypesArray();
   for (StringRef iteratorType : iteratorTypesRange) {
     if (!llvm::is_contained(getAllIteratorTypeNames(), iteratorType) ||
         !utils::symbolizeIteratorType(iteratorType).has_value())

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index eb2f2f1c3cee8..b6d1d21a66591 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1055,17 +1055,12 @@ getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
   }
 
   // Compute all the loops with the reduction iterator types.
-  SmallVector<int64_t> reductionDims;
-  for (const auto &iteratorType :
-       llvm::enumerate(genericOp.getIteratorTypes())) {
-    if (isReductionIterator(iteratorType.value())) {
-      reductionDims.push_back(iteratorType.index());
-    }
-  }
+  SmallVector<unsigned> reductionDims;
+  genericOp.getReductionDims(reductionDims);
 
   llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
   AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
-  auto iteratorTypes = genericOp.getIteratorTypes().getValue();
+  auto iteratorTypes = genericOp.getIteratorTypesArray();
   SmallVector<ReassociationIndices> iterationSpaceReassociation;
   for (ReassociationIndicesRef foldedRangeDims : reassociation) {
     assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
@@ -1085,7 +1080,7 @@ getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
       continue;
 
     // Check that all folded iterator types are all parallel or all reductions.
-    Attribute startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
+    StringRef startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
     if (!isParallelIterator(startIteratorType) &&
         !isReductionIterator(startIteratorType))
       continue;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index d4e3b52f30e3f..d29b767df9d71 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -433,12 +433,9 @@ FailureOr<TileLoopNest> mlir::linalg::tileConsumerAndFuseProducers(
 
   // Search the number of outer parallel loops to separate them from possible
   // inner reduction dimensions.
-  SmallVector<StringAttr> iterTypes =
-      llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>());
+  SmallVector<StringRef> iterTypes = consumerOp.getIteratorTypesArray();
   applyPermutationToVector(iterTypes, tileInterchange);
-  auto *it = find_if(iterTypes, [&](StringAttr iterType) {
-    return !isParallelIterator(iterType);
-  });
+  auto *it = find_if_not(iterTypes, isParallelIterator);
   int64_t split = std::distance(iterTypes.begin(), it);
 
   // Helper to fuse the producers greedily using a queue of fusion candidates.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index a14994f1bd077..3052a4db29464 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -217,7 +217,7 @@ static FailureOr<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
          "expected linalg op with buffer semantics");
 
   auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
-  auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());
+  auto iteratorTypes = linalgOp.getIteratorTypesArray();
 
   SmallVector<Value> allIvs;
   GenerateLoopNest<LoopTy>::doit(

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 56a7437d01f69..57153499a2da2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -215,10 +215,10 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
   newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                    op.getContext()));
   SmallVector<StringRef> newIteratorTypes;
-  for (auto &it : llvm::enumerate(op.iterator_types())) {
+  for (auto &it : llvm::enumerate(op.getIteratorTypesArray())) {
     if (insertSplitDimension == it.index() && !control.innerParallel)
       newIteratorTypes.push_back(getParallelIteratorTypeName());
-    newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
+    newIteratorTypes.push_back(it.value());
     if (insertSplitDimension == it.index() && control.innerParallel)
       newIteratorTypes.push_back(getParallelIteratorTypeName());
   }
@@ -413,8 +413,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
 
   // Step 4. Create the new op matching the original op with an extra parallel
   // dimension.
-  SmallVector<StringRef> iteratorTypes =
-      llvm::to_vector<4>(op.getIteratorTypes().getAsValueRange<StringAttr>());
+  auto iteratorTypes = op.getIteratorTypesArray();
   iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
                        getParallelIteratorTypeName());
   GenericOp genericOp =

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index b0aa41eebc400..b389ceb2892d7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -427,9 +427,8 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
   auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges(
       b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
 
-  SmallVector<Attribute, 4> iteratorTypes;
-  for (const auto &attr :
-       enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
+  SmallVector<StringRef, 4> iteratorTypes;
+  for (const auto &attr : enumerate(op.getIteratorTypesArray())) {
     if (loopIndexToRangeIndex.count(attr.index()))
       iteratorTypes.push_back(attr.value());
   }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 46dc2324eff2d..0a117df6f9ab1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -190,14 +190,8 @@ static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
 }
 
 static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
-  unsigned idx = 0;
-  SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
-  for (auto attr : linalgOp.iterator_types()) {
-    if (isReductionIterator(attr))
-      reductionMask[idx] = true;
-    ++idx;
-  }
-  return reductionMask;
+  return llvm::to_vector(
+      llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
 }
 
 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
@@ -540,7 +534,7 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
 // TODO: probably need some extra checks for reduction followed by consumer
 // ops that may not commute (e.g. linear reduction + non-linear instructions).
 static LogicalResult reductionPreconditions(LinalgOp op) {
-  if (llvm::none_of(op.iterator_types(), isReductionIterator)) {
+  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
     LDBG("reduction precondition failed: no reduction iterator");
     return failure();
   }

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index d39fa11f364a2..999034b4e36b0 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -199,14 +199,12 @@ bool isPermutation(ArrayRef<int64_t> permutation) {
   return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
 }
 
-bool isParallelIterator(Attribute attr) {
-  auto strAttr = attr.dyn_cast_or_null<StringAttr>();
-  return strAttr && strAttr.getValue() == getParallelIteratorTypeName();
+bool isParallelIterator(StringRef iteratorType) {
+  return iteratorType == getParallelIteratorTypeName();
 }
 
-bool isReductionIterator(Attribute attr) {
-  auto strAttr = attr.dyn_cast_or_null<StringAttr>();
-  return strAttr && strAttr.getValue() == getReductionIteratorTypeName();
+bool isReductionIterator(StringRef iteratorType) {
+  return iteratorType == getReductionIteratorTypeName();
 }
 
 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
@@ -484,7 +482,7 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
 template <>
 void GenerateLoopNest<scf::ForOp>::doit(
     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
-    ArrayRef<Attribute> iteratorTypes,
+    ArrayRef<StringRef> iteratorTypes,
     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                   ValueRange)>
         bodyBuilderFn,
@@ -527,7 +525,7 @@ void GenerateLoopNest<scf::ForOp>::doit(
 template <>
 void GenerateLoopNest<AffineForOp>::doit(
     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
-    ArrayRef<Attribute> iteratorTypes,
+    ArrayRef<StringRef> iteratorTypes,
     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                   ValueRange)>
         bodyBuilderFn,
@@ -577,7 +575,7 @@ void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
 // exceeds 10.
 static void generateParallelLoopNest(
     OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
-    ValueRange steps, ArrayRef<Attribute> iteratorTypes,
+    ValueRange steps, ArrayRef<StringRef> iteratorTypes,
     ArrayRef<linalg::ProcInfo> procInfo,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
     SmallVectorImpl<Value> &ivStorage) {
@@ -692,7 +690,7 @@ static void generateParallelLoopNest(
 template <>
 void GenerateLoopNest<scf::ParallelOp>::doit(
     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
-    ArrayRef<Attribute> iteratorTypes,
+    ArrayRef<StringRef> iteratorTypes,
     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                   ValueRange)>
         bodyBuilderFn,

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index ddcf839b1822e..88bd885393b7b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -214,7 +214,7 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
 /// as we use adj matrix for the graph.
 /// The sorted result will put the first Reduction iterator to the
 /// latest possible index.
-static bool topSortOptimal(unsigned n, ArrayRef<Attribute> iteratorTypes,
+static bool topSortOptimal(unsigned n, ArrayRef<StringRef> iteratorTypes,
                            std::vector<unsigned> &topSort,
                            std::vector<unsigned> &inDegree,
                            std::vector<std::vector<bool>> &adjM) {
@@ -289,7 +289,7 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
   unsigned n = op.getNumLoops();
   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
   std::vector<unsigned> inDegree(n, 0); // in-degree of each node.
-  auto iteratorTypes = op.iterator_types().getValue();
+  auto iteratorTypes = op.getIteratorTypesArray();
   // Iterate over the indexing maps of every tensor in the tensor expression.
   for (OpOperand *t : op.getInputAndOutputOperands()) {
     // Skip tensor during cycle resolution.
@@ -361,7 +361,7 @@ static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op,
   // An all-dense annotated "sparse" output tensor becomes a linearized random
   // access 1-dim memref. Also admissible since insertions cannot occur.
   bool allDense = true;
-  auto iteratorTypes = op.iterator_types().getValue();
+  auto iteratorTypes = op.getIteratorTypesArray();
   unsigned numLoops = iteratorTypes.size();
   for (unsigned i = 0; i < numLoops; i++)
     if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
@@ -1299,7 +1299,7 @@ static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   unsigned fb = indices.find_first();
   unsigned tensor = merger.tensor(fb);
   assert(idx == merger.index(fb));
-  auto iteratorTypes = op.iterator_types().getValue();
+  auto iteratorTypes = op.getIteratorTypesArray();
   bool isReduction = linalg::isReductionIterator(iteratorTypes[idx]);
   bool isSparse = merger.isDimLevelType(fb, DimLvlType::kCompressed) ||
                   merger.isDimLevelType(fb, DimLvlType::kSingleton);


        


More information about the Mlir-commits mailing list