[Mlir-commits] [mlir] 1e7a6c0 - [mlir][linalg] Add getIteratorTypesArray to LinalgInterface.

Oleg Shyshkov llvmlistbot at llvm.org
Tue Sep 27 07:31:56 PDT 2022


Author: Oleg Shyshkov
Date: 2022-09-27T14:30:50Z
New Revision: 1e7a6c0874f57014132b857b2d201d6aaa75feee

URL: https://github.com/llvm/llvm-project/commit/1e7a6c0874f57014132b857b2d201d6aaa75feee
DIFF: https://github.com/llvm/llvm-project/commit/1e7a6c0874f57014132b857b2d201d6aaa75feee.diff

LOG: [mlir][linalg] Add getIteratorTypesArray to LinalgInterface.

Summary:
Most of the code that gets `iterator_types` from LinalgInterface is forced to
extract values from an `Attribute`. As a result, the usage pattern looks like
this:

```
SmallVector<StringRef> iterators = llvm::to_vector<4>(linalgOp.iterator_types().getAsValueRange<StringAttr>());
```

It also forces all operations that implement the LinalgOp interface to have an
`iterator_types` attribute even when the information can be easily inferred from
other parameters.

In a perfect future, `getIteratorTypesArray` should be the only method to get
iterator types from the interface. The default implementation can rely on the
`iterator_types` attribute though.

The name `getIteratorTypesArray` was picked to be consistent with the existing
`getIndexingMapsArray`.

This patch adds a few sample usages. More cleanups will follow.

Differential Revision: https://reviews.llvm.org/D134729

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
    mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 64c2bd1b28510..38031704dbf5f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -491,6 +491,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return $_op.iterator_types();
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return iterator types in the current operation.
+      }],
+      /*retTy=*/"SmallVector<StringRef>",
+      /*methodName=*/"getIteratorTypesArray",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto range = $_op.iterator_types().template getAsValueRange<StringAttr>();
+        return {range.begin(), range.end()};
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return true if the indexing map is depending on the current op instance.

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 6e4c2fc9d7393..032578fc5e25b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -297,8 +297,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) {
       !indexingMaps.back().isProjectedPermutation())
     return MatchConvolutionResult::NotProjectedPermutations;
 
-  auto iteratorTypesRange =
-      linalgOp.iterator_types().getAsValueRange<StringAttr>();
+  auto iteratorTypesRange = linalgOp.getIteratorTypesArray();
 
   llvm::SmallDenseSet<unsigned> outputDims =
       getPreservedDims(indexingMaps.back());

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index d58dc4a6540f5..15ba43cb98944 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -438,8 +438,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
       resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
     GenericOp replacementOp = rewriter.create<GenericOp>(
         loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
-        llvm::to_vector<4>(genericOp.getIteratorTypes()
-                               .template getAsValueRange<StringAttr>()));
+        genericOp.getIteratorTypesArray());
     rewriter.inlineRegionBefore(genericOp.getRegion(),
                                 replacementOp.getRegion(),
                                 replacementOp.getRegion().begin());

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 97dd0c448403f..94526e140298c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -467,9 +467,8 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                             .isProjectedPermutation();
                       }) &&
          genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
-         llvm::all_of(genericOp.getIteratorTypes(), [](Attribute attr) {
-           return attr.cast<StringAttr>().getValue() ==
-                  getParallelIteratorTypeName();
+         llvm::all_of(genericOp.getIteratorTypesArray(), [](StringRef it) {
+           return it == getParallelIteratorTypeName();
          });
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 4f384c33e9945..97eee8b3e5091 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -53,8 +53,7 @@ FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
   SmallVector<Value> inputOperands = linalgOp.getInputOperands();
   SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
   SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
-  SmallVector<StringRef> iterators = llvm::to_vector<4>(
-      linalgOp.iterator_types().getAsValueRange<StringAttr>());
+  SmallVector<StringRef> iterators = linalgOp.getIteratorTypesArray();
   SmallVector<RankedTensorType> resultTypes = linalgOp.getOutputTensorTypes();
   SmallVector<Type> types(resultTypes.begin(), resultTypes.end());
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
index e8f5372e4fdc0..a53e7c5ff086e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -61,9 +61,7 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
     SmallVector<Value> outputOperands = genericOp.getOutputOperands();
     auto newOp = rewriter.create<GenericOp>(
         loc, genericOp->getResultTypes(), newOperands, outputOperands,
-        newIndexingMaps,
-        llvm::to_vector<4>(genericOp.getIteratorTypes()
-                               .template getAsValueRange<StringAttr>()));
+        newIndexingMaps, genericOp.getIteratorTypesArray());
     rewriter.cloneRegionBefore(genericOp.getRegion(), newOp.getRegion(),
                                newOp.getRegion().begin());
 


        


More information about the Mlir-commits mailing list