[Mlir-commits] [mlir] c38d9cf - [mlir] Remove iterator_types() method from LinalgStructuredInterface.

Oleg Shyshkov llvmlistbot at llvm.org
Thu Oct 13 00:53:11 PDT 2022


Author: Oleg Shyshkov
Date: 2022-10-13T07:52:43Z
New Revision: c38d9cf20e7468a2618dc23fcdc66e79c925aff5

URL: https://github.com/llvm/llvm-project/commit/c38d9cf20e7468a2618dc23fcdc66e79c925aff5
DIFF: https://github.com/llvm/llvm-project/commit/c38d9cf20e7468a2618dc23fcdc66e79c925aff5.diff

LOG: [mlir] Remove iterator_types() method from LinalgStructuredInterface.

`getIteratorTypesArray` should be used instead. It's a better substitute for all the current usages of the interface.

The current `ArrayAttr iterator_types()` has a few problems:
* It creates an assumption that the operation has iterator types as an attribute, but that is not always the case. Sometimes iterator types can be inferred from other attributes, or they are just static.
* ArrayAttr is an obscure container and requires extracting the values in client code.
* Makes it hard to migrate iterator types from strings to enums ([RFC](https://discourse.llvm.org/t/rfc-enumattr-for-iterator-types-in-linalg/64535/9)).

Concrete ops, like `linalg.generic`, will still have iterator types as an attribute if needed.

As a side effect, this change helps a bit with migration to prefixed accessors.

Differential Revision: https://reviews.llvm.org/D135765

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 268587e312077..69871faa5a2cc 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -497,28 +497,21 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return $_op.getBody();
       }]
     >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the iterator types attribute within the current operation.
-      }],
-      /*retTy=*/"ArrayAttr",
-      /*methodName=*/"iterator_types",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        return $_op.getIteratorTypes();
-      }]
-    >,
     InterfaceMethod<
       /*desc=*/[{
         Return iterator types in the current operation.
+
+        Default implementation assumes that the operation has an attribute
+        `iterator_types`, but it's not always the case. Sometimes iterator types
+        can be infered from other parameters and in such cases default
+        getIteratorTypesArray should be overriden.
       }],
       /*retTy=*/"SmallVector<StringRef>",
       /*methodName=*/"getIteratorTypesArray",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        auto range = $_op.iterator_types().template getAsValueRange<StringAttr>();
+        auto range = $_op.getIteratorTypes().template getAsValueRange<StringAttr>();
         return {range.begin(), range.end()};
       }]
     >,
@@ -773,9 +766,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     LogicalResult reifyResultShapes(OpBuilder &b,
         ReifiedRankedShapedTypeDims &reifiedReturnShapes);
 
-    // TODO: Remove once prefixing is flipped.
-    ArrayAttr getIteratorTypes() { return iterator_types(); }
-
     SmallVector<StringRef> getIteratorTypeNames() {
       return getIteratorTypesArray();
     }

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 6bcf509844cb8..2619ad1186408 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -264,7 +264,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
     // Implement functions necessary for LinalgStructuredInterface.
-    ArrayAttr getIteratorTypes();
+    SmallVector<StringRef> getIteratorTypesArray();
     ArrayAttr getIndexingMaps();
     std::string getLibraryCallName() {
       return "op_has_no_registered_library_name";
@@ -334,7 +334,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
     // Declare functions necessary for LinalgStructuredInterface.
-    ArrayAttr getIteratorTypes();
+    SmallVector<StringRef> getIteratorTypesArray();
     ArrayAttr getIndexingMaps();
     std::string getLibraryCallName() {
       return "op_has_no_registered_library_name";

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 61b938601e103..c2705a383550a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1393,11 +1393,9 @@ LogicalResult MapOp::verify() {
   return success();
 }
 
-ArrayAttr MapOp::getIteratorTypes() {
+SmallVector<StringRef> MapOp::getIteratorTypesArray() {
   int64_t rank = getInit().getType().getRank();
-  return Builder(getContext())
-      .getStrArrayAttr(
-          SmallVector<StringRef>(rank, getParallelIteratorTypeName()));
+  return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
 }
 
 ArrayAttr MapOp::getIndexingMaps() {
@@ -1435,13 +1433,13 @@ void ReduceOp::getAsmResultNames(
     setNameFn(getResults().front(), "reduced");
 }
 
-ArrayAttr ReduceOp::getIteratorTypes() {
+SmallVector<StringRef> ReduceOp::getIteratorTypesArray() {
   int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
   SmallVector<StringRef> iteratorTypes(inputRank,
                                        getParallelIteratorTypeName());
   for (int64_t reductionDim : getDimensions())
     iteratorTypes[reductionDim] = getReductionIteratorTypeName();
-  return Builder(getContext()).getStrArrayAttr(iteratorTypes);
+  return iteratorTypes;
 }
 
 ArrayAttr ReduceOp::getIndexingMaps() {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index a9113f0d05713..66d55dcf5c713 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -92,11 +92,9 @@ struct LinalgOpTilingInterface
   /// Return the loop iterator type.
   SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
     LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
-    return llvm::to_vector(
-        llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) {
-          return utils::symbolizeIteratorType(
-                     strAttr.cast<StringAttr>().getValue())
-              .value();
+    return llvm::to_vector(llvm::map_range(
+        concreteOp.getIteratorTypesArray(), [](StringRef iteratorType) {
+          return utils::symbolizeIteratorType(iteratorType).value();
         }));
   }
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index a28aa3b72ff43..2458dabef56c9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -250,7 +250,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
     // Fuse producer and consumer into a new generic op.
     auto fusedOp = rewriter.create<GenericOp>(
         loc, op.getResult(0).getType(), inputOps, outputOps,
-        rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.iterator_types(),
+        rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
         /*doc=*/nullptr, /*library_call=*/nullptr);
     Block &prodBlock = prod.getRegion().front();
     Block &consBlock = op.getRegion().front();

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 21ee7b5119763..1418ed4da4a4a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1857,7 +1857,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     if (op.getNumOutputs() != 1)
       return failure();
     unsigned numTensors = op.getNumInputsAndOutputs();
-    unsigned numLoops = op.iterator_types().getValue().size();
+    unsigned numLoops = op.getNumLoops();
     Merger merger(numTensors, numLoops);
     if (!findSparseAnnotations(merger, op))
       return failure();

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 1fca43ddff411..7f329126b2892 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2816,7 +2816,7 @@ def TestLinalgConvOp :
       return &regionBuilder;
     }
 
-    mlir::ArrayAttr iterator_types() {
+    mlir::ArrayAttr getIteratorTypes() {
       return getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
     }
 
@@ -2875,7 +2875,7 @@ def TestLinalgFillOp :
       return &regionBuilder;
     }
 
-    mlir::ArrayAttr iterator_types() {
+    mlir::ArrayAttr getIteratorTypes() {
       return getOperation()->getAttrOfType<mlir::ArrayAttr>("iterator_types");
     }
 

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index 205cde22c9970..51557f519772a 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -235,7 +235,7 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: value
 
-#       IMPL:  Test3Op::iterator_types() {
+#       IMPL:  Test3Op::getIteratorTypesArray() {
 #  IMPL-NEXT:    int64_t rank = getRank(getOutputOperand(0));
 
 #       IMPL:  Test3Op::getIndexingMaps() {

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 1bea98fe71bfb..8156bb97a32f3 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -553,7 +553,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
 
     let extraClassDeclaration = structuredOpsBaseDecls # [{{
       // Auto-generated.
-      ArrayAttr iterator_types();
+      SmallVector<StringRef> getIteratorTypesArray();
       ArrayAttr getIndexingMaps();
       static void regionBuilder(ImplicitLocOpBuilder &b,
                                 Block &block, ArrayRef<NamedAttribute> attrs);
@@ -587,24 +587,24 @@ static const char structuredOpBuilderFormat[] = R"FMT(
   }]>
 )FMT";
 
-// The iterator_types() method for structured ops. Parameters:
+// The getIteratorTypesArray() method for structured ops. Parameters:
 // {0}: Class name
 // {1}: Comma interleaved iterator type names.
 static const char structuredOpIteratorTypesFormat[] =
     R"FMT(
-ArrayAttr {0}::iterator_types() {{
-  return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef>{{ {1} });
+SmallVector<StringRef> {0}::getIteratorTypesArray() {{
+  return SmallVector<StringRef>{{ {1} };
 }
 )FMT";
 
-// The iterator_types() method for rank polymorphic structured ops. Parameters:
+// The getIteratorTypesArray() method for rank polymorphic structured ops.
+// Parameters:
 // {0}: Class name
 static const char rankPolyStructuredOpIteratorTypesFormat[] =
     R"FMT(
-ArrayAttr {0}::iterator_types() {{
+SmallVector<StringRef> {0}::getIteratorTypesArray() {{
   int64_t rank = getRank(getOutputOperand(0));
-  return Builder(getContext()).getStrArrayAttr(
-    SmallVector<StringRef>(rank, getParallelIteratorTypeName()));
+  return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
 }
 )FMT";
 


        


More information about the Mlir-commits mailing list