[Mlir-commits] [mlir] bfaf535 - [mlir][Linalg] Refactor in preparation for automatic Linalg "named" ops.

Nicolas Vasilache llvmlistbot at llvm.org
Wed Feb 12 11:50:10 PST 2020


Author: Nicolas Vasilache
Date: 2020-02-12T14:47:40-05:00
New Revision: bfaf535791897f3cc2af40d4f5a677489ad25940

URL: https://github.com/llvm/llvm-project/commit/bfaf535791897f3cc2af40d4f5a677489ad25940
DIFF: https://github.com/llvm/llvm-project/commit/bfaf535791897f3cc2af40d4f5a677489ad25940.diff

LOG: [mlir][Linalg] Refactor in preparation for automatic Linalg "named" ops.

This revision prepares the ground for declaratively defining Linalg "named" ops.
Such named ops form the backbone of operations that are ubiquitous in the ML
application domain.

This revision is closely related to the definition of a "Tensor Computation
Primitives Dialect" and demonstrates that ops can be expressed as declarative
configurations of the `linalg.generic` op.

Differential Revision: https://reviews.llvm.org/D74491

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index 305159dba2f5..2e269890b43f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgTraits.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index d31424ac599f..8914bcfe546b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -39,19 +39,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     // Loop types handling.
     //========================================================================//
     InterfaceMethod<
-      "Query the number of parallel loops within the current operation.",
+      "Return the number of parallel loops within the current operation.",
       "unsigned", "getNumParallelLoops"
     >,
     InterfaceMethod<
-      "Query the number of reduction loops within the current operation.",
+      "Return the number of reduction loops within the current operation.",
       "unsigned", "getNumReductionLoops"
     >,
     InterfaceMethod<
-      "Query the number of window loops within the current operation.",
+      "Return the number of window loops within the current operation.",
       "unsigned", "getNumWindowLoops"
     >,
     InterfaceMethod<
-      "Query the number of loops within the current operation.",
+      "Return the number of loops within the current operation.",
       "unsigned", "getNumLoops">,
 
     InterfaceMethod<
@@ -63,10 +63,10 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     // Input arguments handling.
     //========================================================================//
     InterfaceMethod<
-      "Query the number of inputs from the current operation.",
+      "Return the number of inputs from the current operation.",
       "unsigned", "getNumInputs"
     >,
-    InterfaceMethod<"Query the input view at the given index.",
+    InterfaceMethod<"Return the input view at the given index.",
       "Value ", "getInput", (ins "unsigned":$i)
     >,
     InterfaceMethod<[{
@@ -76,41 +76,40 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       "llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value ":$v)
     >,
     InterfaceMethod<
-      "Query the input operands from the current operation.",
+      "Return the input operands from the current operation.",
       "Operation::operand_range", "getInputs"
     >,
     InterfaceMethod<[{
-        Query the type of the input shape at the given index.
+        Return the type of the input shape at the given index.
       }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>,
     InterfaceMethod<[{
-        Query the subset of input operands that are of ranked tensor type.
+        Return the subset of input operands that are of ranked tensor type.
       }], "SmallVector<RankedTensorType, 4>", "getInputTensorTypes">,
 
-
     //========================================================================//
     // Output arguments handling.
     //========================================================================//
     InterfaceMethod<
-      "Query the number of outputs from the current operation.",
+      "Return the number of outputs from the current operation.",
       "unsigned", "getNumOutputs"
     >,
-    InterfaceMethod<"Query the output buffer at the given index.",
+    InterfaceMethod<"Return the output buffer at the given index.",
       "Value ", "getOutputBuffer", (ins "unsigned":$i)
     >,
     InterfaceMethod<[{
-        Query the index of the given buffer value, or `None` if the value is not
-        part of the output buffers.
+        Return the index of the given buffer value, or `None` if the value is
+        not part of the output buffers.
       }],
       "llvm::Optional<unsigned>", "getIndexOfOutputBuffer", (ins "Value ":$view)
     >,
     InterfaceMethod<[{
-        Query the type of the output buffer at the given index.
+        Return the type of the output buffer at the given index.
       }], "MemRefType", "getOutputBufferType", (ins "unsigned":$i)>,
     InterfaceMethod<[{
-        Query the results that are of ranked tensor type.
+        Return the results that are of ranked tensor type.
       }], "SmallVector<RankedTensorType, 4>", "getOutputTensorTypes">,
     InterfaceMethod<
-      "Query the output buffers (operands) from the current operation.",
+      "Return the output buffers (operands) from the current operation.",
       "Operation::operand_range", "getOutputBuffers"
     >,
 
@@ -136,18 +135,44 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     // Other interface methods.
     //========================================================================//
     InterfaceMethod<
-      "Query the iterator types attribute within the current operation.",
+      "Return the reference iterators for this named op (if any are specied). "
+      "These reference iterators are used to specify the default behavior of "
+      "the op. Typically this would be a static method but in order to allow "
+      "rank-polymorphic ops, this needs to be per object instance. Named ops "
+      "must define referenceIterators, even if empty for the 0-D case. "
+      "Generic ops on the other hand have a None `referenceIterators`",
+      "llvm::Optional<SmallVector<StringRef, 8>>", "referenceIterators"
+    >,
+    InterfaceMethod<
+      "Return the reference indexing maps for this named op (if any are "
+      "specified). Typically this would be a static method but in order to "
+      "allow rank-polymorphic ops, this needs to be per object instance. Named "
+      "ops must define referenceIterators, even if empty for the 0-D case. "
+      "Generic ops on the other hand have a None `referenceIndexingMaps`",
+      "llvm::Optional<SmallVector<AffineMap, 8>>", "referenceIndexingMaps"
+    >,
+    InterfaceMethod<
+      "Return the iterator types attribute within the current operation.",
       "ArrayAttr", "iterator_types"
     >,
     InterfaceMethod<
-      "Query the indexing maps attribute within the current operation.",
+      "Return the indexing maps attribute within the current operation.",
       "ArrayAttr", "indexing_maps"
     >,
+    InterfaceMethod<"Return the input or output indexing map at index `i`.",
+      "AffineMap", "getIndexingMap", (ins "unsigned":$i)
+    >,
+    InterfaceMethod<"Return the input indexing map at index `i`.",
+      "AffineMap", "getInputIndexingMap", (ins "unsigned":$i)
+    >,
+    InterfaceMethod<"Return the output indexing map at index `i`.",
+      "AffineMap", "getOutputIndexingMap", (ins "unsigned":$i)
+    >,
     InterfaceMethod<[{
-        Query whether the op has only MemRef input and outputs.
+        Return whether the op has only MemRef input and outputs.
       }], "bool", "hasBufferSemantics">,
     InterfaceMethod<[{
-        Query whether the op has only RankedTensor input and outputs.
+        Return whether the op has only RankedTensor input and outputs.
       }], "bool", "hasTensorSemantics">,
 
     //========================================================================//
@@ -204,7 +229,7 @@ class LinalgStructured_Op<string mnemonic, list<OpTrait> props>
 }
 
 ////////////////////////////////////////////////////////////////////////////////
-// Concrete Linalg ops.
+// Named Linalg ops, implemented as special configurations of a generic op.
 ////////////////////////////////////////////////////////////////////////////////
 def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> {
   let description = [{
@@ -266,14 +291,19 @@ def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> {
       builder, result, input, output, AffineMapAttr(), AffineMapAttr());
   }]>];
   let extraClassDeclaration = libraryCallName # [{
+    // Defined in C++ for now.
+    // TODO(ntv): auto-generate.
     ArrayAttr indexing_maps();
 
-    ArrayAttr iterator_types() {
+    // Rank-polymorphic.
+    //   filling_value -> O(ivs) with parallel iterators.
+    llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
       unsigned nPar = input().getType().cast<ShapedType>().getRank();
-      MLIRContext *ctx = getContext();
-      SmallVector<Attribute, 8> iters(
-        nPar, StringAttr::get(getParallelIteratorTypeName(), ctx));
-      return ArrayAttr::get(iters, ctx);
+      return SmallVector<StringRef, 8>(nPar, getParallelIteratorTypeName());
+    }
+
+    llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
+      llvm_unreachable("NYI referenceIndexingMaps for CopyOp");
     }
   }];
   let verifier = [{ return ::verify(*this); }];
@@ -282,21 +312,24 @@ def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> {
 }
 
 def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
-  let arguments = (ins AnyStridedMemRef:$input,
+  let arguments = (ins AnyStridedMemRef:$output,
                    AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>:$value);
   let extraClassDeclaration = libraryCallName # [{
+    // Defined in C++ for now.
+    // TODO(ntv): auto-generate.
     ArrayAttr indexing_maps();
 
-    ArrayAttr iterator_types() {
-      unsigned nPar = input().getType().cast<ShapedType>().getRank();
-      MLIRContext *ctx = getContext();
-      SmallVector<Attribute, 8> iters(
-        nPar, StringAttr::get(getParallelIteratorTypeName(), ctx));
-      return ArrayAttr::get(iters, ctx);
+    // Rank-polymorphic.
+    //   filling_value -> O(ivs) with parallel iterators.
+    llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
+      unsigned nPar = output().getType().cast<ShapedType>().getRank();
+      return SmallVector<StringRef, 8>(nPar, getParallelIteratorTypeName());
     }
-  }];
-  let verifier = [{ return ::verify(*this); }];
 
+    llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
+      llvm_unreachable("NYI referenceIndexingMaps for CopyOp");
+    }
+  }];
   let hasFolder = 1;
 }
 
@@ -305,12 +338,16 @@ def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
                        AnyStridedMemRefOfRank<1>,
                        AnyStridedMemRefOfRank<0>);
   let extraClassDeclaration = libraryCallName # [{
+    // Defined in C++ for now.
+    // TODO(ntv): auto-generate.
     ArrayAttr indexing_maps();
 
-    ArrayAttr iterator_types() {
-      MLIRContext *ctx = getContext();
-      return ArrayAttr::get(
-        StringAttr::get(getReductionIteratorTypeName(), ctx), ctx);
+    llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
+      return SmallVector<StringRef, 8>{getReductionIteratorTypeName()};
+    }
+
+    llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
+      llvm_unreachable("NYI referenceIndexingMaps for DotOp");
     }
   }];
 
@@ -322,14 +359,18 @@ def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> {
                        AnyStridedMemRefOfRank<1>,
                        AnyStridedMemRefOfRank<1>);
   let extraClassDeclaration = libraryCallName # [{
+    // Defined in C++ for now.
+    // TODO(ntv): auto-generate.
     ArrayAttr indexing_maps();
 
-    ArrayAttr iterator_types() {
-      MLIRContext *ctx = getContext();
-      Attribute iters[2]{
-        StringAttr::get(getParallelIteratorTypeName(), ctx),
-        StringAttr::get(getReductionIteratorTypeName(), ctx)};
-      return ArrayAttr::get(iters, ctx);
+    llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
+      return SmallVector<StringRef, 8>{
+        getParallelIteratorTypeName(),
+        getReductionIteratorTypeName()};
+    }
+
+    llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
+      llvm_unreachable("NYI referenceIndexingMaps for MatvecOp");
     }
   }];
 
@@ -341,15 +382,19 @@ def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
                        AnyStridedMemRefOfRank<2>,
                        AnyStridedMemRefOfRank<2>);
   let extraClassDeclaration = libraryCallName # [{
+    // Defined in C++ for now.
+    // TODO(ntv): auto-generate.
     ArrayAttr indexing_maps();
 
-    ArrayAttr iterator_types() {
-      MLIRContext *ctx = getContext();
-      Attribute iters[3]{
-        StringAttr::get(getParallelIteratorTypeName(), ctx),
-        StringAttr::get(getParallelIteratorTypeName(), ctx),
-        StringAttr::get(getReductionIteratorTypeName(), ctx)};
-      return ArrayAttr::get(iters, ctx);
+    llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
+      return SmallVector<StringRef, 8>{
+        getParallelIteratorTypeName(),
+        getParallelIteratorTypeName(),
+        getReductionIteratorTypeName()};
+    }
+
+    llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
+      llvm_unreachable("NYI referenceIndexingMaps for MatmulOp");
     }
   }];
 
@@ -387,11 +432,13 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
     unsigned getNumInputFeatureDimensions() { return 1; }
     unsigned getNumOutputFeatureDimensions() { return 1; }
 
+    // Defined in C++ for now.
+    // TODO(ntv): auto-generate.
     ArrayAttr indexing_maps();
 
-    ArrayAttr iterator_types() {
+    llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
       // Outer parallel loops are always the number of output dimensions; i.e.
-      // [ b, xs, q] in the TF notation above.
+      // [b, xs, q] in the TF notation above.
       unsigned nPar = getOutputShapedType(0).getRank();
       unsigned nRed = getNumInputFeatureDimensions();
       // Window loops are a special kind of reduction that is never tiled or
@@ -400,13 +447,11 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
       // This may evolve in the future.
       unsigned nWin =
         nPar - getNumBatchDimensions() - getNumInputFeatureDimensions();
-      MLIRContext *ctx = getContext();
-      SmallVector<Attribute, 8> iters(
-        nPar, StringAttr::get(getParallelIteratorTypeName(), ctx));
+      SmallVector<StringRef, 8> iters(nPar, getParallelIteratorTypeName());
       iters.reserve(nPar + nRed + nWin);
-      iters.append(nRed, StringAttr::get(getReductionIteratorTypeName(), ctx));
-      iters.append(nWin, StringAttr::get(getWindowIteratorTypeName(), ctx));
-      return ArrayAttr::get(iters, ctx);
+      iters.append(nRed, getReductionIteratorTypeName());
+      iters.append(nWin, getWindowIteratorTypeName());
+      return iters;
     }
 
     int64_t getStride(unsigned i) {
@@ -422,6 +467,10 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
       return dilations()->getValue()[i]
         .cast<IntegerAttr>().getValue().getSExtValue();
     }
+
+    llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
+      llvm_unreachable("NYI referenceIndexingMaps for MatmulOp");
+    }
   }];
 
   let verifier = [{ return ::verify(*this); }];
@@ -438,6 +487,9 @@ class LinalgOperandOfRank<int rank>: Type<
     CPred<"$_self.cast<ShapedType>().getRank() == " # rank>]
   >>;
 
+////////////////////////////////////////////////////////////////////////////////
+// Generic Linalg ops.
+////////////////////////////////////////////////////////////////////////////////
 class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
   let arguments = (ins Variadic<LinalgOperand>:$views,
                    I64Attr:$args_in,
@@ -457,34 +509,36 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
         getIteratorTypesAttrName()
       };
     }
+
     unsigned getNumInputs() { return args_in().getSExtValue(); }
+
     unsigned getNumOutputs() { return args_out().getSExtValue(); }
+
     FuncOp getFunction() {
       auto moduleOp = getParentOfType<ModuleOp>();
       return fun().hasValue() ?
         moduleOp.lookupSymbol<FuncOp>(fun().getValue()) : FuncOp();
     }
+
     StringRef getLibraryCallName() {
       return library_call().hasValue() ? library_call().getValue() : "";
     }
-    AffineMap getIndexingMap(unsigned i) {
-      assert(i < getNumInputsAndOutputs());
-      return indexing_maps().getValue()[i].cast<AffineMapAttr>().getValue();
-    }
-    AffineMap getInputIndexingMap(unsigned i) {
-      assert(i < getNumInputs());
-      return indexing_maps().getValue()[i].cast<AffineMapAttr>().getValue();
-    }
-    AffineMap getOutputIndexingMap(unsigned i) {
-      assert(i < getNumOutputs());
-      return indexing_maps().getValue()[i + getNumInputs()]
-          .cast<AffineMapAttr>().getValue();
-    }
+
+    llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
+      llvm_unreachable(
+        "No such thing as reference iterator types for a generic op.");
+     }
+
+    llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
+      llvm_unreachable(
+        "No such thing as reference indexing maps for a generic op.");
+     }
   }];
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parseGenericOp(parser, result); }];
 }
 
+/// Index-free GenericOp.
 def GenericOp : GenericOpBase<"generic"> {
   let description = [{
     Generic Linalg op form where the key properties of the computation are
@@ -609,6 +663,8 @@ def GenericOp : GenericOpBase<"generic"> {
   let hasFolder = 1;
 }
 
+/// GenericOp with Indexing (i.e. multi-for style in which the region is passed
+/// the enclosing loop induction variables)
 def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
   let description = [{
     Indexed Generic Linalg op form where the key properties of the computation

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
index bfe528dabd3c..31b462ab8ba6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Support/LLVM.h"
@@ -214,6 +215,103 @@ class StructuredOpTraits
   //==========================================================================//
   // Other interface methods.
   //==========================================================================//
+
+  // Get or build the iterator_types ArrayAttr.
+  ArrayAttr iterator_types() {
+    // Return the attribute if it is present.
+    if (auto attr = this->getOperation()->getAttr("iterator_types"))
+      return attr.template cast<ArrayAttr>();
+
+    // If not, form the attribute using the reference iterator types for the
+    // ConcreteType.
+    auto maybeReferenceIteratorTypes =
+        cast<ConcreteType>(this->getOperation()).referenceIterators();
+
+    // If there is no reference, this must be a generic op.
+    // TODO(ntv): Traits are used to define ops. Split into cpp to avoid
+    // cyclic dependency.
+    auto name = this->getOperation()->getName().getStringRef();
+    if (!maybeReferenceIteratorTypes && name != "generic" &&
+        name != "indexed_generic") {
+      this->getOperation()->dump();
+      llvm_unreachable("Op missing ");
+    }
+
+    // If we have a reference, build the reference attribute.
+    auto *ctx = this->getOperation()->getContext();
+    auto attrRange = llvm::map_range(*maybeReferenceIteratorTypes,
+                                     [ctx](StringRef str) -> Attribute {
+                                       return StringAttr::get(str, ctx);
+                                     });
+    auto attr = ArrayAttr::get(llvm::to_vector<4>(attrRange), ctx);
+    // TODO(ntv): Need to memoize this. Can't just store as an attribute atm as
+    // it will impact parser, printer and tests.
+    // this->getOperation()->setAttr("iterator_types", attr);
+    return attr;
+  }
+
+  // Get or build the indexing_maps ArrayAttr.
+  ArrayAttr indexing_maps() {
+    // Return the attribute if it is present.
+    if (auto attr = this->getOperation()->getAttr("indexing_maps"))
+      return attr.template cast<ArrayAttr>();
+
+    // If not, form the attribute using the reference indexing map for the
+    // ConcreteType.
+    auto maybeReferenceIndexingMaps =
+        cast<ConcreteType>(this->getOperation()).referenceIndexingMaps();
+
+    // If there is no reference, this must be a generic op.
+    auto name = this->getOperation()->getName().getStringRef();
+    if (!maybeReferenceIndexingMaps && name != "generic" &&
+        name != "indexed_generic") {
+      this->getOperation()->dump();
+      llvm_unreachable("Op missing referenceIndexingMaps");
+    }
+
+    // If we have a reference, build the reference attribute and return it.
+    // The attribute is rebuilt on every call; it is not stored on the op.
+    auto *ctx = this->getOperation()->getContext();
+    auto attrRange =
+        llvm::map_range(*maybeReferenceIndexingMaps, [ctx](AffineMap map) {
+          // 0-D corner case because there is no such thing as a concrete empty
+          // map type.
+          if (!map)
+            map = AffineMap::get(0, 0, getAffineConstantExpr(0, ctx));
+          return AffineMapAttr::get(map);
+        });
+    SmallVector<Attribute, 4> attrs{attrRange.begin(), attrRange.end()};
+    auto attr = ArrayAttr::get(attrs, ctx);
+    // TODO(ntv): Need to memoize this. Can't just store as an attribute atm as
+    // it will impact parser, printer and tests.
+    // this->getOperation()->setAttr("indexing_maps", attr);
+    return attr;
+  }
+
+  AffineMap getIndexingMap(unsigned i) {
+    assert(i < getNumInputsAndOutputs());
+    return indexing_maps()
+        .getValue()[i]
+        .template cast<AffineMapAttr>()
+        .getValue();
+  }
+
+  AffineMap getInputIndexingMap(unsigned i) {
+    assert(i < nInputs());
+    return indexing_maps()
+        .getValue()[i]
+        .template cast<AffineMapAttr>()
+        .getValue();
+  }
+
+  AffineMap getOutputIndexingMap(unsigned i) {
+    assert(i < nOutputs());
+    return indexing_maps()
+        .getValue()[i + nInputs()]
+        .template cast<AffineMapAttr>()
+        .getValue();
+  }
+
   /// Query whether the op has only buffer inputs and no returns.
   bool hasBufferSemantics() {
     return this->getOperation()->getNumResults() == 0 &&

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index fb18fbf02f38..c5fbea9d5802 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -866,6 +866,15 @@ static LogicalResult verify(ConvOp op) {
   return success();
 }
 
+static AffineMap extractOrIdentityMap(Optional<AffineMap> maybeMap,
+                                      unsigned rank, MLIRContext *context) {
+  if (maybeMap)
+    return maybeMap.getValue();
+  if (rank == 0)
+    return AffineMap();
+  return AffineMap::getMultiDimIdentityMap(rank, context);
+}
+
 namespace mlir {
 namespace linalg {
 
@@ -880,15 +889,6 @@ namespace linalg {
 } // namespace linalg
 } // namespace mlir
 
-static AffineMap extractOrIdentityMap(Optional<AffineMap> maybeMap,
-                                      unsigned rank, MLIRContext *context) {
-  if (maybeMap)
-    return maybeMap.getValue();
-  if (rank == 0)
-    return AffineMap();
-  return AffineMap::getMultiDimIdentityMap(rank, context);
-}
-
 // Returns `num` AffineDimExpr dimensions at positions [curIdx, curIdx + num)
 // and increments `curIdx` to `curIdx + num`.
 static SmallVector<AffineExpr, 4>
@@ -997,23 +997,15 @@ SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
         AffineMap::get(idx, 0, concat(concat(bs, ws), qs)),
         // output[b, x[0], ..., x[N-1], k]
         AffineMap::get(idx, 0, concat(concat(bs, xs), ks))};
-  } else if (auto genericOp = dyn_cast<GenericOp>(op)) {
-    SmallVector<AffineMap, 4> res;
-    unsigned nViews = genericOp.getNumInputsAndOutputs();
-    res.reserve(nViews);
-    for (unsigned i = 0, e = nViews; i < e; ++i) {
-      res.push_back(genericOp.getIndexingMap(i));
-    }
-    return res;
-  } else if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op)) {
-    SmallVector<AffineMap, 4> res;
-    unsigned nViews = indexedGenericOp.getNumInputsAndOutputs();
-    res.reserve(nViews);
-    for (unsigned i = 0, e = nViews; i < e; ++i)
-      res.push_back(indexedGenericOp.getIndexingMap(i));
-    return res;
   }
-  llvm_unreachable("Missing loopToOperandRangesMaps for op");
+  SmallVector<AffineMap, 4> res;
+  auto linalgOp = cast<LinalgOp>(op);
+  unsigned nViews = linalgOp.getNumInputsAndOutputs();
+  res.reserve(nViews);
+  for (unsigned i = 0, e = nViews; i < e; ++i)
+    res.push_back(linalgOp.getIndexingMap(i));
+  assert(nViews == linalgOp.indexing_maps().size());
+  return res;
 }
 
 static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {


        


More information about the Mlir-commits mailing list