[Mlir-commits] [mlir] 3bdd7fc - [mlir][Linalg] Add support to lower named ops to loops.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Apr 30 10:49:13 PDT 2020


Author: Nicolas Vasilache
Date: 2020-04-30T13:45:17-04:00
New Revision: 3bdd7fcc3404001ad919da6b9acc677199793787

URL: https://github.com/llvm/llvm-project/commit/3bdd7fcc3404001ad919da6b9acc677199793787
DIFF: https://github.com/llvm/llvm-project/commit/3bdd7fcc3404001ad919da6b9acc677199793787.diff

LOG: [mlir][Linalg] Add support to lower named ops to loops.

This revision adds support for lowering named ops to loops.
linalg.batch_matmul successfully lowers to loops and to LLVM.

In the process, the updated tests also activate the Linalg-to-affine-loops lowering.
However, padded convolutions do not lower to affine.load at the moment, so this revision overrides the type of the underlying load / store operation for the padded input access.

Differential Revision: https://reviews.llvm.org/D79135
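
For readers who want to drive the same lowering from C++ rather than through the mlir-opt RUN lines added below, here is a minimal sketch. The header paths, the createConvertLinalgToAffineLoopsPass() factory and the nested-pass API are assumptions based on the MLIR tree of this era rather than part of this patch; adjust them to your checkout.

```
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;

// Sketch: lower Linalg (including named ops such as linalg.batch_matmul) to
// affine loops, mirroring `mlir-opt -convert-linalg-to-affine-loops`.
// Pass factory and header names are era-specific assumptions.
static LogicalResult lowerLinalgToAffineLoops(ModuleOp module) {
  PassManager pm(module.getContext());
  // The Linalg lowering passes run on functions, so nest them under FuncOp.
  pm.addNestedPass<FuncOp>(createConvertLinalgToAffineLoopsPass());
  return pm.run(module);
}
```

The RUN lines in affine.mlir and loops.mlir below exercise the same path, followed by -convert-linalg-to-llvm to check the full descent to LLVM.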

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
    mlir/test/Dialect/Linalg/affine.mlir
    mlir/test/Dialect/Linalg/loops.mlir
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
index 1c427faff693..b7bba5a31011 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -351,10 +351,11 @@ template <typename ConcreteType>
 class NamedStructuredOpTraits
     : public OpTrait::TraitBase<ConcreteType, NamedStructuredOpTraits> {
 public:
-  llvm::Optional<SmallVector<StringRef, 8>> referenceIterators();
-  llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps();
-  std::function<void(OpBuilder &, Location, ArrayRef<Value>)>
-  emitScalarImplementation();
+  static SmallVector<StringRef, 8> referenceIterators(TypeRange inputTypes,
+                                                      TypeRange outputTypes);
+
+  static SmallVector<AffineMap, 8> referenceIndexingMaps(TypeRange inputTypes,
+                                                         TypeRange outputTypes);
 };
 
 } // namespace linalg
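
The bodies behind these two static hooks are produced by mlir-linalg-ods-gen (see the generator changes and the IMPL check lines near the end of this patch). As a rough sketch of the generated shape for a matmul-like Test2Op, reconstructed from those checks rather than copied from actual generator output:

```
// Illustrative only: reconstructed from the IMPL checks in
// test-linalg-ods-gen.tc; such bodies live alongside the other generated
// code in LinalgOps.cpp, with `using namespace mlir::linalg;` in scope.
SmallVector<StringRef, 8> Test2Op::referenceIterators(TypeRange inputTypes,
                                                      TypeRange outputTypes) {
  return SmallVector<StringRef, 8>{getParallelIteratorTypeName(),
                                   getParallelIteratorTypeName(),
                                   getReductionIteratorTypeName()};
}

SmallVector<AffineMap, 8> Test2Op::referenceIndexingMaps(TypeRange inputTypes,
                                                         TypeRange outputTypes) {
  assert(!inputTypes.empty() && "At least one input expected");
  MLIRContext *context = (*inputTypes.begin()).getContext();
  AffineExpr d0, d1, d2;
  bindDims(context, d0, d1, d2);
  // A(d0, d2) * B(d2, d1) -> C(d0, d1), i.e. (parallel, parallel, reduction).
  return SmallVector<AffineMap, 8>{AffineMap::get(3, 0, {d0, d2}, context),
                                   AffineMap::get(3, 0, {d2, d1}, context),
                                   AffineMap::get(3, 0, {d0, d1}, context)};
}
```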

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 974bff525f96..82ae6de83c83 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -33,10 +33,9 @@ using namespace mlir::linalg;
 
 /// Forward declarations.
 template <typename NamedStructuredOpType>
-static void buildNamedStructuredOpRegion(Builder &builder,
-                                         OperationState &result,
-                                         TypeRange operandTypes,
-                                         TypeRange tensorResultTypes);
+static void buildNamedStructuredOpRegionAndAttributes(
+    Builder &builder, OperationState &result, TypeRange operandTypes,
+    TypeRange tensorResultTypes);
 template <typename NamedStructuredOpType>
 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
 template <typename NamedStructuredOpType>
@@ -1085,9 +1084,10 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
 //===----------------------------------------------------------------------===//
 
 template <typename NamedStructuredOpType>
-void buildNamedStructuredOpRegion(Builder &builder, OperationState &result,
-                                  TypeRange operandTypes,
-                                  TypeRange tensorResultTypes) {
+void buildNamedStructuredOpRegionAndAttributes(Builder &builder,
+                                               OperationState &result,
+                                               TypeRange operandTypes,
+                                               TypeRange tensorResultTypes) {
   Region &region = *result.addRegion();
   Block *body = new Block();
   // TODO: atm all operands go through getElementTypeOrSelf,
@@ -1102,12 +1102,24 @@ void buildNamedStructuredOpRegion(Builder &builder, OperationState &result,
   opBuilder.setInsertionPointToStart(&region.front());
   mlir::edsc::ScopedContext scope(opBuilder, builder.getUnknownLoc());
   NamedStructuredOpType::regionBuilder(*body);
+
+  auto indexingMaps = builder.getAffineMapArrayAttr(
+      NamedStructuredOpType::referenceIndexingMaps(operandTypes,
+                                                   tensorResultTypes));
+  result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
+
+  auto iterators =
+      builder.getStrArrayAttr(NamedStructuredOpType::referenceIterators(
+          operandTypes, tensorResultTypes));
+  result.addAttribute(getIteratorTypesAttrName(), iterators);
 }
 
 template <typename NamedStructuredOpType>
 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
+  std::array<StringRef, 2> silentAttrNames{getIndexingMapsAttrName(),
+                                           getIteratorTypesAttrName()};
   p << op.getOperationName() << ' ';
-  p.printOptionalAttrDict(op.getAttrs());
+  p.printOptionalAttrDict(op.getAttrs(), silentAttrNames);
   p << ' ' << op.getOperands();
   p << ": (" << op.getOperandTypes() << ")";
   auto outputTensorTypes = op.getResultTypes();
@@ -1139,7 +1151,7 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
   if (!tensorResultTypes.empty())
     result.addTypes(tensorResultTypes);
 
-  buildNamedStructuredOpRegion<NamedStructuredOpType>(
+  buildNamedStructuredOpRegionAndAttributes<NamedStructuredOpType>(
       parser.getBuilder(), result, operandTypes, tensorResultTypes);
 
   return parser.resolveOperands(operandsInfo, operandTypes,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index 4a6d54c97865..62a5a02ea3be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -78,11 +78,10 @@ SmallVector<Value, 4> emitLoopRanges(OpBuilder &b, Location loc, AffineMap map,
   return res;
 }
 
-template <typename OpType>
-static void
-inlineRegionAndEmitStdStore(OpType op, ArrayRef<Value> indexedValues,
-                            ArrayRef<SmallVector<Value, 8>> indexing,
-                            ArrayRef<Value> outputBuffers) {
+template <typename IndexedValueType, typename OpType>
+static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
+                                     ArrayRef<SmallVector<Value, 8>> indexing,
+                                     ArrayRef<Value> outputBuffers) {
   auto &b = ScopedContext::getBuilder();
   auto &block = op.region().front();
   BlockAndValueMapping map;
@@ -95,10 +94,10 @@ inlineRegionAndEmitStdStore(OpType op, ArrayRef<Value> indexedValues,
 
   Operation &terminator = block.back();
   assert(isa<YieldOp>(terminator) &&
-         "expected an yield op in the end of the region");
+         "expected a yield op in the end of the region");
   for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
-    std_store(map.lookupOrDefault(terminator.getOperand(i)), outputBuffers[i],
-              ArrayRef<Value>{indexing[i].begin(), indexing[i].end()});
+    IndexedValueType O(outputBuffers[i]);
+    O(indexing[i]) = map.lookupOrDefault(terminator.getOperand(i));
   }
 }
 
@@ -123,9 +122,36 @@ static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
 
 namespace {
 
-// Generic loop emitter, to be specialized on an op-per op basis.
-// TODO: Hook up to named ops interface and, later, retire when all named ops
-// are auto-generated.
+/// Emits the MLIR for the scalar part of the generic op by:
+///   1. Emitting load ops for each input and output view in order. This is
+///      achieved by applying the appropriate input or output map to the
+///      enclosing induction variables.
+///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
+///      from point 1. above.
+///   3. Emitting store ops to store the results of 2. to the output
+///      views.
+///
+/// An example output may resemble:
+///
+/// ```
+///    loop.for %i = %c0 to %0 step %c1 {
+///      loop.for %j = %c0 to %1 step %c1 {
+///        loop.for %k = %c0 to %4 step %c1 {
+///          %11 = load %arg0[%i, %j] :
+///            memref<?x?xf32, stride_specification>
+///          %12 = load %arg1[%i, %j, %k] :
+///            memref<?x?x?xf32, stride_specification>
+///          %13 = load %arg2[%i, %k, %j] :
+///            memref<?x?x?xf32, stride_specification>
+///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
+///          store %14#0, %arg1[%i, %j, %k] :
+///            memref<?x?x?Xf32, stride_specification>
+///          store %14#1, %arg2[%i, %k, %j] :
+///            memref<?x?x?Xf32, stride_specification>
+///       }
+///      }
+///    }
+/// ```
 template <typename IndexedValueType, typename LinalgOpType>
 class LinalgScopedEmitter {
 public:
@@ -133,9 +159,43 @@ class LinalgScopedEmitter {
                                        LinalgOpType linalgOp) {
     assert(linalgOp.hasBufferSemantics() &&
            "expected linalg op with buffer semantics");
-    llvm_unreachable("NYI");
-    linalgOp.emitScalarImplementation()(ScopedContext::getBuilder(),
-                                        ScopedContext::getLocation(), allIvs);
+    auto b = ScopedContext::getBuilder();
+    auto loc = ScopedContext::getLocation();
+    unsigned nInputs = linalgOp.getNumInputs();
+    unsigned nOutputs = linalgOp.getNumOutputs();
+    SmallVector<Value, 4> indexedValues;
+    indexedValues.reserve(nInputs + nOutputs);
+
+    // TODO(mravishankar): Avoid the loads if the corresponding argument of the
+    // region has no uses.
+    // 1.a. Emit load from input views.
+    for (unsigned i = 0; i < nInputs; ++i) {
+      auto indexing = makeCanonicalAffineApplies(
+          b, loc, linalgOp.getInputIndexingMap(i), allIvs);
+      // Passing through IndexedValueType emits the proper load operation.
+      indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing));
+    }
+    // 1.b. Emit load from output views.
+    for (unsigned i = 0; i < nOutputs; ++i) {
+      auto indexing = makeCanonicalAffineApplies(
+          b, loc, linalgOp.getOutputIndexingMap(i), allIvs);
+      // Passing through IndexedValueType emits the proper load operation.
+      indexedValues.push_back(
+          IndexedValueType(linalgOp.getOutputBuffer(i))(indexing));
+    }
+
+    // TODO(ntv): When a region inliner exists, use it.
+    // 2. Inline region, currently only works for a single basic block.
+    // 3. Emit store.
+    SmallVector<SmallVector<Value, 8>, 8> indexing;
+    SmallVector<Value, 8> outputBuffers;
+    for (unsigned i = 0; i < nOutputs; ++i) {
+      indexing.push_back(makeCanonicalAffineApplies(
+          b, loc, linalgOp.getOutputIndexingMap(i), allIvs));
+      outputBuffers.push_back(linalgOp.getOutputBuffer(i));
+    }
+    inlineRegionAndEmitStore<IndexedValueType>(linalgOp, indexedValues,
+                                               indexing, outputBuffers);
   }
 };
 
@@ -231,7 +291,7 @@ class LinalgScopedEmitter<IndexedValueType, ConvOp> {
 public:
   /// Returns the input value of convOp. If the indices in `imIdx` is out of
   /// boundary, returns 0 instead.
-  static Value getConvOpInput(ConvOp convOp, IndexedValueType im,
+  static Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
                               MutableArrayRef<Value> imIdx) {
     // TODO(ntv): add a level of indirection to linalg.generic.
     if (!convOp.padding())
@@ -293,7 +353,11 @@ class LinalgScopedEmitter<IndexedValueType, ConvOp> {
         makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
     SmallVector<Value, 8> oIdx(
         makeCanonicalAffineApplies(b, loc, maps[2], allIvs));
-    IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output());
+
+    // Padded conv involves an affine.max in the memory access which is not
+    // allowed by affine.load. Override to always use an StdIndexedValue.
+    StdIndexedValue I(convOp.input());
+    IndexedValueType F(convOp.filter()), O(convOp.output());
 
     // Emit scalar form.
     Value paddedInput = getConvOpInput(convOp, I, imIdx);
@@ -344,111 +408,36 @@ class LinalgScopedEmitter<IndexedValueType, PoolingSumOp> {
   }
 };
 
-// Emits the MLIR for the scalar part of the generic op by:
-//   1. Emitting std_load and std_store ops for each input and output
-//      view in order. This is achieved by applying the appropriate input or
-//      output map to the enclosing induction variables.
-//   2. Emitting a call to `op.fun()` that takes as arguments the scalars
-//      from point 1. above.
-//   3. Emitting std_store to store the results of 2. to the output
-//      views.
-//
-// An example output may resemble:
-//
-// ```
-//    loop.for %i = %c0 to %0 step %c1 {
-//      loop.for %j = %c0 to %1 step %c1 {
-//        loop.for %k = %c0 to %4 step %c1 {
-//          %11 = load %arg0[%i, %j] :
-//            memref<?x?xf32, stride_specification>
-//          %12 = load %arg1[%i, %j, %k] :
-//            memref<?x?x?xf32, stride_specification>
-//          %13 = load %arg2[%i, %k, %j] :
-//            memref<?x?x?xf32, stride_specification>
-//          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
-//          store %14#0, %arg1[%i, %j, %k] :
-//            memref<?x?x?Xf32, stride_specification>
-//          store %14#1, %arg2[%i, %k, %j] :
-//            memref<?x?x?Xf32, stride_specification>
-//       }
-//      }
-//    }
-// ```
-template <typename IndexedValueType>
-class LinalgScopedEmitter<IndexedValueType, GenericOp> {
-public:
-  static void emitScalarImplementation(ArrayRef<Value> allIvs,
-                                       GenericOp genericOp) {
-    assert(genericOp.hasBufferSemantics() &&
-           "expected linalg op with buffer semantics");
-    auto b = ScopedContext::getBuilder();
-    auto loc = ScopedContext::getLocation();
-    unsigned nInputs = genericOp.getNumInputs();
-    unsigned nOutputs = genericOp.getNumOutputs();
-    SmallVector<Value, 4> indexedValues(nInputs + nOutputs);
-
-    // 1.a. Emit std_load from input views.
-    for (unsigned i = 0; i < nInputs; ++i) {
-      auto indexing = makeCanonicalAffineApplies(
-          b, loc, genericOp.getInputIndexingMap(i), allIvs);
-      indexedValues[i] = std_load(genericOp.getInput(i), indexing);
-    }
-
-    // 1.b. Emit std_load from output views.
-    // TODO(mravishankar): Avoid the loads if the corresponding argument of the
-    // region has no uses.
-    for (unsigned i = 0; i < nOutputs; ++i) {
-      Value output = genericOp.getOutputBuffer(i);
-      auto indexing = makeCanonicalAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs);
-      indexedValues[nInputs + i] = std_load(output, indexing);
-    }
-
-    // TODO(ntv): When a region inliner exists, use it.
-    // 2. Inline region, currently only works for a single basic block.
-    // 3. Emit std_store.
-    SmallVector<SmallVector<Value, 8>, 8> indexing;
-    SmallVector<Value, 8> outputBuffers;
-    for (unsigned i = 0; i < nOutputs; ++i) {
-      indexing.push_back(makeCanonicalAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-      outputBuffers.push_back(genericOp.getOutputBuffer(i));
-    }
-    inlineRegionAndEmitStdStore(genericOp, indexedValues, indexing,
-                                outputBuffers);
-  }
-};
-
-// Emits the MLIR for the scalar part of the indexed generic op by:
-//   1. Emitting std_load and std_store ops for each input and output view in
-//      order. This is achieved by applying the appropriate input or output map
-//      to the enclosing induction variables.
-//   2. Emitting a call to `op.fun()` that takes as arguments the induction
-//      variables and the scalars from point 1. above.
-//   3. Emitting std_store to store the results of 2. to the output views.
-//
-// An example output may resemble:
-//
-// ```
-//    loop.for %i = %c0 to %0 step %c1 {
-//      loop.for %j = %c0 to %1 step %c1 {
-//        loop.for %k = %c0 to %4 step %c1 {
-//          %11 = load %arg0[%i, %j] :
-//            memref<?x?xf32, stride_specification>
-//          %12 = load %arg1[%i, %j, %k] :
-//            memref<?x?x?xf32, stride_specification>
-//          %13 = load %arg2[%i, %k, %j] :
-//            memref<?x?x?xf32, stride_specification>
-//          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
-//            (index, index, index, f32, f32, f32) -> (f32, f32)
-//          store %14#0, %arg1[%i, %j, %k] :
-//            memref<?x?x?Xf32, stride_specification>
-//          store %14#1, %arg2[%i, %k, %j] :
-//            memref<?x?x?Xf32, stride_specification>
-//       }
-//      }
-//    }
-// ```
+/// Emits the MLIR for the scalar part of the indexed generic op by:
+///   1. Emitting load ops for each input and output view in order. This is
+///      achieved by applying the appropriate input or output map to the
+///      enclosing induction variables.
+///   2. Emitting a call to `op.fun()` that takes as arguments the induction
+///      variables and the scalars from point 1. above.
+///   3. Emitting store ops to store the results of 2. to the output views.
+///
+/// An example output may resemble:
+///
+/// ```
+///    loop.for %i = %c0 to %0 step %c1 {
+///      loop.for %j = %c0 to %1 step %c1 {
+///        loop.for %k = %c0 to %4 step %c1 {
+///          %11 = load %arg0[%i, %j] :
+///            memref<?x?xf32, stride_specification>
+///          %12 = load %arg1[%i, %j, %k] :
+///            memref<?x?x?xf32, stride_specification>
+///          %13 = load %arg2[%i, %k, %j] :
+///            memref<?x?x?xf32, stride_specification>
+///          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
+///            (index, index, index, f32, f32, f32) -> (f32, f32)
+///          store %14#0, %arg1[%i, %j, %k] :
+///            memref<?x?x?Xf32, stride_specification>
+///          store %14#1, %arg2[%i, %k, %j] :
+///            memref<?x?x?Xf32, stride_specification>
+///       }
+///      }
+///    }
+/// ```
 template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
 public:
@@ -461,31 +450,33 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
     unsigned nInputs = indexedGenericOp.getNumInputs();
     unsigned nOutputs = indexedGenericOp.getNumOutputs();
     unsigned nLoops = allIvs.size();
-    SmallVector<Value, 4> indexedValues(nLoops + nInputs + nOutputs);
-
-    for (unsigned i = 0; i < nLoops; ++i) {
-      indexedValues[i] = allIvs[i];
-    }
+    SmallVector<Value, 4> indexedValues;
+    indexedValues.reserve(nLoops + nInputs + nOutputs);
+    for (unsigned i = 0; i < nLoops; ++i)
+      indexedValues.push_back(allIvs[i]);
 
-    // 1.a. Emit std_load from input views.
+    // TODO(mravishankar): Avoid the loads if the corresponding argument of the
+    // region has no uses.
+    // 1.a. Emit load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
-      Value input = indexedGenericOp.getInput(i);
       auto indexing = makeCanonicalAffineApplies(
           b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs);
-      indexedValues[nLoops + i] = std_load(input, indexing);
+      // Passing input i through IndexedValueType emits the proper load operation.
+      indexedValues.push_back(
+          IndexedValueType(indexedGenericOp.getInput(i))(indexing));
     }
-
-    // 1.b. Emit std_load from output views.
+    // 1.b. Emit load from output views.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      Value output = indexedGenericOp.getOutputBuffer(i);
       auto indexing = makeCanonicalAffineApplies(
           b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs);
-      indexedValues[nLoops + nInputs + i] = std_load(output, indexing);
+      // Passing output i through IndexedValueType emits the proper load operation.
+      indexedValues.push_back(
+          IndexedValueType(indexedGenericOp.getOutputBuffer(i))(indexing));
     }
 
     // TODO(ntv): When a region inliner exists, use it.
     // 2. Inline region, currently only works for a single basic block.
-    // 3. Emit std_store.
+    // 3. Emit store.
     SmallVector<SmallVector<Value, 8>, 8> indexing;
     SmallVector<Value, 8> outputBuffers;
     for (unsigned i = 0; i < nOutputs; ++i) {
@@ -493,19 +484,19 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
           b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
       outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
     }
-    inlineRegionAndEmitStdStore(indexedGenericOp, indexedValues, indexing,
-                                outputBuffers);
+    inlineRegionAndEmitStore<IndexedValueType>(indexedGenericOp, indexedValues,
+                                               indexing, outputBuffers);
   }
 };
 
-// This struct is for factoring out the implementation and support template
-// instantiations in the following 2 cases:
-//   1. Appending to a list of patterns via RewritePatternList.
-//   2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`.
-// The implementation must work both in DRR and inside a RewritePattern. As a
-// consequence, (1) it is only allowed to emit new ops if the match is
-// guaranteed to be a success, (2) it is not allowed erase/replace, and (3) an
-// encompassing pattern must take care of the erasure logic.
+/// This struct is for factoring out the implementation and support template
+/// instantiations in the following 2 cases:
+///   1. Appending to a list of patterns via RewritePatternList.
+///   2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`.
+/// The implementation must work both in DRR and inside a RewritePattern. As a
+/// consequence, (1) it is only allowed to emit new ops if the match is
+/// guaranteed to be a success, (2) it is not allowed erase/replace, and (3) an
+/// encompassing pattern must take care of the erasure logic.
 template <typename LoopTy, typename ConcreteOpTy>
 class LinalgOpToLoopsImpl {
 public:
@@ -532,7 +523,7 @@ class GenerateLoopNest {
   }
 };
 
-/// Generates loops nest using loop.parallel. loop.parallel is only used for the
+/// Generates loop nest using loop.parallel. loop.parallel is only used for the
 /// outer parallel loops. All other loops are generated using loop.for
 /// operation.
 template <typename ConcreteOpTy>
@@ -652,7 +643,7 @@ class LinalgRewritePattern : public RewritePattern {
   }
 };
 
-// Helper classes for type list expansion.
+/// Helper classes for type list expansion.
 template <typename LoopType, typename... LinalgOps>
 class RewritePatternList;
 
@@ -680,16 +671,16 @@ void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
                      >::build(patterns, ctx);
 }
 
-// Local folding pattern for AffineApplyOp that we can apply greedily.
-// This replaces AffineApplyOp by the proper value in cases where the associated
-// map is trivial. A trivial map here is defined as a map with a single result
-// and either:
-//   1. Zero operand + returns a single AffineConstantExpr
-//   2. One operand + returns a single AffineDimExpr
-//   3. One operands + returns a single AffineSymbolExpr
+/// Local folding pattern for AffineApplyOp that we can apply greedily.
+/// This replaces AffineApplyOp by the proper value in cases where the
+/// associated map is trivial.
+/// A trivial map here is defined as a map with a single result and either:
+///   1. Zero operand + returns a single AffineConstantExpr
+///   2. One operand + returns a single AffineDimExpr
+///   3. One operand + returns a single AffineSymbolExpr
 //
-// In the first case, the AffineApplyOp is replaced by a new constant. In the
-// other cases, it is replaced by its unique operand.
+/// In the first case, the AffineApplyOp is replaced by a new constant. In the
+/// other cases, it is replaced by its unique operand.
 struct FoldAffineOp : public RewritePattern {
   FoldAffineOp(MLIRContext *context)
       : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}
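
A note on the IndexedValueType template parameter threaded through inlineRegionAndEmitStore and the scoped emitters above: constructing the wrapper around a buffer and "calling" it with indices reads as a load, while assigning through the result writes as a store, so the same emitter body produces either std or affine memory ops depending on the instantiation. The following self-contained toy (not the actual MLIR EDSC classes) illustrates the idiom by printing instead of building ops:

```
#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for the std/affine indexed-value helpers: operator() models a
// load, assignment through the returned proxy models a store.
class ToyIndexedValue {
public:
  explicit ToyIndexedValue(std::string buffer) : buffer(std::move(buffer)) {}

  struct Access {
    std::string buffer;
    std::string indices;

    // Reading an access models emitting a load op.
    operator std::string() const {
      std::cout << "  load " << buffer << "[" << indices << "]\n";
      return buffer + "[" + indices + "]";
    }
    // Assigning to an access models emitting a store op.
    Access &operator=(const std::string &value) {
      std::cout << "  store " << value << ", " << buffer << "[" << indices
                << "]\n";
      return *this;
    }
  };

  Access operator()(const std::vector<std::string> &indices) const {
    std::string joined;
    for (size_t i = 0, e = indices.size(); i < e; ++i)
      joined += (i ? ", " : "") + indices[i];
    return Access{buffer, joined};
  }

private:
  std::string buffer;
};

int main() {
  ToyIndexedValue O("%output");
  std::string v = O({"%i", "%j"}); // models: %v = load %output[%i, %j]
  O({"%i", "%j"}) = v;             // models: store %v, %output[%i, %j]
  return 0;
}
```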

diff  --git a/mlir/test/Dialect/Linalg/affine.mlir b/mlir/test/Dialect/Linalg/affine.mlir
index 70457825ce4f..dfe130a44efb 100644
--- a/mlir/test/Dialect/Linalg/affine.mlir
+++ b/mlir/test/Dialect/Linalg/affine.mlir
@@ -1,13 +1,15 @@
 // RUN: mlir-opt %s -convert-linalg-to-affine-loops | FileCheck %s
 
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
-// RUN: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
+// RUN: mlir-opt %s -convert-linalg-to-affine-loops -convert-linalg-to-llvm -o=/dev/null 2>&1
 
 // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
 
 // CHECK-DAG: #[[stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
 
+// CHECK-DAG: #[[clampMinMap:.*]] = affine_map<(d0) -> (d0, 0)>
+
 func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -53,3 +55,69 @@ func @conv_view3(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg1:
 //       CHECK:         affine.for %{{.*}} = 0 to %[[Q]] {
 //       CHECK:           affine.for %{{.*}} = 0 to %[[Z0]] {
 //       CHECK:            %[[SUM:.*]] = affine.apply #[[stride2Dilation1]](%{{.*}}, %{{.*}})
+
+func @conv_padding(%arg0: memref<?x?x?x?xf32>,
+                   %arg1: memref<?x?x?x?xf32>,
+                   %arg2: memref<?x?x?x?xf32>) {
+  linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1],
+                                    padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>,
+                                    strides = [1, 1]} :
+    memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+  return
+}
+// CHECK-LABEL: func @conv_padding
+//       CHECK: %{{.*}}: memref<?x?x?x?xf32>, %{{.*}}: memref<?x?x?x?xf32>, %{{.*}}: memref<?x?x?x?xf32>) {
+//       CHECK:   %[[ZERO:.*]] = constant 0.000000e+00 : f32
+//       CHECK:   %[[Z0:.*]] = dim %arg0, 0 : memref<?x?x?x?xf32>
+//       CHECK:   %[[Z1:.*]] = dim %arg0, 1 : memref<?x?x?x?xf32>
+//       CHECK:   %[[Q:.*]] =  dim %arg0, 2 : memref<?x?x?x?xf32>
+//       CHECK:   %[[K:.*]] =  dim %arg0, 3 : memref<?x?x?x?xf32>
+//       CHECK:   %[[B:.*]] =  dim %arg1, 0 : memref<?x?x?x?xf32>
+//       CHECK:   %[[X0:.*]] = dim %arg2, 1 : memref<?x?x?x?xf32>
+//       CHECK:   %[[X1:.*]] = dim %arg2, 2 : memref<?x?x?x?xf32>
+//       CHECK:   affine.for %{{.*}} = 0 to %[[B]] {
+//       CHECK:     affine.for %{{.*}} = 0 to %[[X0]] {
+//       CHECK:       affine.for %{{.*}} = 0 to %[[X1]] {
+//       CHECK:         affine.for %{{.*}} = 0 to %[[K]] {
+//       CHECK:           affine.for %{{.*}} = 0 to %[[Q]] {
+//       CHECK:             affine.for %{{.*}} = 0 to %[[Z0]] {
+//       CHECK:               affine.for %{{.*}} = 0 to %[[Z1]] {
+//       CHECK:                 %[[SUM0:.*]] = affine.apply #{{.*}}(%{{.*}}, %{{.*}})
+//       CHECK:                 %[[SUM1:.*]] = affine.apply #{{.*}}(%{{.*}}, %{{.*}})
+//       CHECK:                 %[[IDX:.*]] = affine.max #[[clampMinMap]](%[[SUM0]])
+//       CHECK:                 %[[IDY:.*]] = affine.max #[[clampMinMap]](%[[SUM1]])
+// Padded conv involves an affine.max in the memory access which is not
+// allowed by affine.load. Override to always use an std.load.
+//       CHECK:                 %{{.*}} = load %{{.*}}[%{{.*}}, %[[IDX]], %[[IDY]], %{{.*}}] : memref<?x?x?x?xf32>
+//       CHECK:                 %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : f32
+//       CHECK:                 %{{.*}} = affine.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32>
+//       CHECK:                 %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
+//       CHECK:                 %{{.*}} = affine.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32>
+//       CHECK:                 %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+//       CHECK:                 affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32>
+
+//----------------------------------------------------------------------------//
+// Named ops to loops.
+//----------------------------------------------------------------------------//
+func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
+  linalg.batch_matmul %A, %B, %C : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
+  return
+}
+// CHECK-LABEL: @named_batch_matmul
+//  CHECK-SAME: %[[mA:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECK-SAME: %[[mB:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECK-SAME: %[[mC:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//       CHECK: %[[B:.*]] = dim %[[mA]], 0 : memref<?x?x?xf32>
+//       CHECK: %[[M:.*]] = dim %[[mA]], 1 : memref<?x?x?xf32>
+//       CHECK: %[[K:.*]] = dim %[[mA]], 2 : memref<?x?x?xf32>
+//       CHECK: %[[N:.*]] = dim %[[mB]], 2 : memref<?x?x?xf32>
+//       CHECK: affine.for %[[b:.*]] = 0 to %[[B]] {
+//       CHECK:   affine.for %[[m:.*]] = 0 to %[[M]] {
+//       CHECK:     affine.for %[[n:.*]] = 0 to %[[N]] {
+//       CHECK:       affine.for %[[k:.*]] = 0 to %[[K]] {
+//       CHECK:       %[[va:.*]] = affine.load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref<?x?x?xf32>
+//       CHECK:       %[[vb:.*]] = affine.load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref<?x?x?xf32>
+//       CHECK:       %[[vc:.*]] = affine.load %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32>
+//       CHECK:       %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+//       CHECK:       %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECK:       affine.store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32>

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 7c71dbf893c9..6075b9824731 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -2,7 +2,7 @@
 // RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefix=CHECKPARALLEL %s
 
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
-// RUN: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -o=/dev/null 2>&1
 
 // CHECKLOOP-DAG: #[[strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
 // CHECKLOOP-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
@@ -354,7 +354,6 @@ func @conv_view4(%arg0: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, %
 //       CHECKPARALLEL:           %{{.*}} = addf %{{.*}}, %{{.*}} : f32
 //       CHECKPARALLEL:           store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32, #[[strided4D]]>
 
-
 func @conv_padding(%arg0: memref<?x?x?x?xf32>,
                    %arg1: memref<?x?x?x?xf32>,
                    %arg2: memref<?x?x?x?xf32>) {
@@ -854,8 +853,8 @@ func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>)
 //  CHECKLOOP-SAME: %[[ARG1]]: memref<f32>
 //  CHECKLOOP-SAME: %[[ARG2]]: memref<f32>
 //   CHECKLOOP-NOT: loop.for
-//   CHECKLOOP-DAG: load %[[ARG0]][]
-//   CHECKLOOP-DAG: load %[[ARG1]][]
+//       CHECKLOOP: load %[[ARG0]][]
+//       CHECKLOOP: load %[[ARG1]][]
 //       CHECKLOOP: addf
 //       CHECKLOOP: store %{{.*}}, %[[ARG2]][]
 
@@ -864,7 +863,50 @@ func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>)
 //  CHECKPARALLEL-SAME: %[[ARG1]]: memref<f32>
 //  CHECKPARALLEL-SAME: %[[ARG2]]: memref<f32>
 //   CHECKPARALLEL-NOT: loop.for
-//   CHECKPARALLEL-DAG: load %[[ARG0]][]
-//   CHECKPARALLEL-DAG: load %[[ARG1]][]
+//       CHECKPARALLEL: load %[[ARG0]][]
+//       CHECKPARALLEL: load %[[ARG1]][]
 //       CHECKPARALLEL: addf
 //       CHECKPARALLEL: store %{{.*}}, %[[ARG2]][]
+
+//----------------------------------------------------------------------------//
+// Named ops to loops.
+//----------------------------------------------------------------------------//
+func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
+  linalg.batch_matmul %A, %B, %C : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
+  return
+}
+// CHECKLOOP-LABEL: @named_batch_matmul
+//  CHECKLOOP-SAME: %[[mA:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECKLOOP-SAME: %[[mB:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECKLOOP-SAME: %[[mC:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//       CHECKLOOP: %[[B:.*]] = dim %[[mA]], 0 : memref<?x?x?xf32>
+//       CHECKLOOP: %[[M:.*]] = dim %[[mA]], 1 : memref<?x?x?xf32>
+//       CHECKLOOP: %[[K:.*]] = dim %[[mA]], 2 : memref<?x?x?xf32>
+//       CHECKLOOP: %[[N:.*]] = dim %[[mB]], 2 : memref<?x?x?xf32>
+//       CHECKLOOP: loop.for %[[b:.*]] = %{{.*}} to %[[B]] step %{{.*}} {
+//       CHECKLOOP:   loop.for %[[m:.*]] = %{{.*}} to %[[M]] step %{{.*}} {
+//       CHECKLOOP:     loop.for %[[n:.*]] = %{{.*}} to %[[N]] step %{{.*}} {
+//       CHECKLOOP:       loop.for %[[k:.*]] = %{{.*}} to %[[K]] step %{{.*}} {
+//       CHECKLOOP:       %[[va:.*]] = load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref<?x?x?xf32>
+//       CHECKLOOP:       %[[vb:.*]] = load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref<?x?x?xf32>
+//       CHECKLOOP:       %[[vc:.*]] = load %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32>
+//       CHECKLOOP:       %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+//       CHECKLOOP:       %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKLOOP:       store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32>
+
+// CHECKPARALLEL-LABEL: @named_batch_matmul
+//  CHECKPARALLEL-SAME: %[[mA:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECKPARALLEL-SAME: %[[mB:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECKPARALLEL-SAME: %[[mC:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[B:.*]] = dim %[[mA]], 0 : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[M:.*]] = dim %[[mA]], 1 : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[K:.*]] = dim %[[mA]], 2 : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[N:.*]] = dim %[[mB]], 2 : memref<?x?x?xf32>
+//       CHECKPARALLEL: loop.parallel (%[[b:.*]], %[[m:.*]], %[[n:.*]]) = ({{.*}}) to (%[[B]], %[[M]], %[[N]]) step ({{.*}}) {
+//       CHECKPARALLEL:   loop.for %[[k:.*]] = %{{.*}} to %[[K]] step %{{.*}} {
+//       CHECKPARALLEL:       %[[va:.*]] = load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:       %[[vb:.*]] = load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:       %[[vc:.*]] = load %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:       %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+//       CHECKPARALLEL:       %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKPARALLEL:       store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32>

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index 0b88f2aa11a2..d796d1917c03 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -7,15 +7,15 @@
 //  ODS-NEXT:   NamedStructuredOpTraits
 //  ODS-NEXT:   SingleBlockImplicitTerminator<"YieldOp">
 //
-// IMPL-LABEL:  Test1Op::referenceIterators() {
-//  IMPL-NEXT:  { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
+// IMPL-LABEL:  SmallVector<StringRef, 8> Test1Op::referenceIterators
+//       IMPL:  { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
 //
-//       IMPL:  Test1Op::referenceIndexingMaps() {
+//       IMPL:  SmallVector<AffineMap, 8> Test1Op::referenceIndexingMaps
 //       IMPL:  AffineMap::get(2, 0, {d0, d1}, context),
 //  IMPL-NEXT:  AffineMap::get(2, 0, {d1}, context),
 //  IMPL-NEXT:  AffineMap::get(2, 0, {d0}, context) };
 //
-//       IMPL:  Test1Op::regionBuilder(Block &block) {
+//       IMPL:  void Test1Op::regionBuilder(Block &block) {
 //       IMPL:  Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 //       IMPL:  Value [[d:.*]] = std_mulf([[a]], [[b]]);
 //       IMPL:  Value [[e:.*]] = std_addf([[c]], [[d]]);
@@ -32,10 +32,10 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
 //  ODS-NEXT:   NamedStructuredOpTraits
 //  ODS-NEXT:   SingleBlockImplicitTerminator<"YieldOp">
 //
-// IMPL-LABEL:  Test2Op::referenceIterators() {
-//  IMPL-NEXT:  { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
+// IMPL-LABEL:  SmallVector<StringRef, 8> Test2Op::referenceIterators
+//       IMPL:  { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
 //
-//       IMPL:  Test2Op::referenceIndexingMaps() {
+//       IMPL:  SmallVector<AffineMap, 8> Test2Op::referenceIndexingMaps
 //       IMPL:  AffineMap::get(3, 0, {d0, d2}, context),
 //  IMPL-NEXT:  AffineMap::get(3, 0, {d2, d1}, context),
 //  IMPL-NEXT:  AffineMap::get(3, 0, {d0, d1}, context) };
@@ -57,10 +57,10 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
 //  ODS-NEXT:   NamedStructuredOpTraits
 //  ODS-NEXT:   SingleBlockImplicitTerminator<"YieldOp">
 //
-// IMPL-LABEL:  Test3Op::referenceIterators() {
-//  IMPL-NEXT:  { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
+// IMPL-LABEL:  SmallVector<StringRef, 8> Test3Op::referenceIterators
+//       IMPL:  { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
 //
-//       IMPL:  Test3Op::referenceIndexingMaps() {
+//       IMPL:  SmallVector<AffineMap, 8> Test3Op::referenceIndexingMaps
 //       IMPL:  AffineMap::get(4, 0, {d0, d1, d3}, context),
 //  IMPL-NEXT:  AffineMap::get(4, 0, {d3, d2}, context),
 //  IMPL-NEXT:  AffineMap::get(4, 0, {d0, d1, d2}, context) };

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 424a29716368..d2dd1f5d9738 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1472,7 +1472,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
         [{{
           result.addOperands(views);
           result.addTypes(outputTypes);
-          buildNamedStructuredOpRegion<{0}>(
+          buildNamedStructuredOpRegionAndAttributes<{0}>(
             b, result, TypeRange(views), outputTypes);
         }]>
       ];
@@ -1481,7 +1481,13 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
       }];
       let extraClassDeclaration = [{{
         llvm::Optional<SmallVector<StringRef, 8>> referenceIterators();
+        static SmallVector<StringRef, 8> referenceIterators(
+          TypeRange inputTypes, TypeRange outputTypes);
+
         llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps();
+        static SmallVector<AffineMap, 8> referenceIndexingMaps(
+          TypeRange inputTypes, TypeRange outputTypes);
+
         static void regionBuilder(Block &block);
       }];
   })FMT";
@@ -1503,7 +1509,13 @@ void TCParser::printReferenceIterators(llvm::raw_ostream &os,
                                        ComprehensionParsingState &state) {
   const char *referenceReferenceIteratorsFmt =
       R"FMT(
-    llvm::Optional<SmallVector<StringRef, 8>> {0}::referenceIterators() {
+    // This is temporary until we transition out of manually specified ops
+    // that should be auto-generated with linalg-ods-gen.
+    llvm::Optional<SmallVector<StringRef, 8>> {0}::referenceIterators() {{
+      llvm_unreachable("Unexpected missing `iterator_types` attribute.");
+    }
+    SmallVector<StringRef, 8> {0}::referenceIterators(
+      TypeRange inputTypes, TypeRange outputTypes) {
       return SmallVector<StringRef, 8>{{ {1} };
     })FMT";
 
@@ -1536,15 +1548,27 @@ void TCParser::printReferenceIterators(llvm::raw_ostream &os,
 void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
                                           StringRef cppOpName,
                                           ComprehensionParsingState &state) {
+  // 1. Generic string template for specifying reference indexing maps.
   const char *referenceIndexingMapsFmt =
       R"FMT(
-  llvm::Optional<SmallVector<AffineMap, 8>> {0}::referenceIndexingMaps() {
-    MLIRContext *context = getContext();
+  // This is temporary until we transition out of manually specified ops that
+  // should be auto-generated with linalg-ods-gen.
+  llvm::Optional<SmallVector<AffineMap, 8>> {0}::referenceIndexingMaps() {{
+    llvm_unreachable("Unexpected missing `indexing_maps` attribute.");
+  }
+  SmallVector<AffineMap, 8> {0}::referenceIndexingMaps(
+    TypeRange inputTypes, TypeRange outputTypes) {
+    assert(!inputTypes.empty() && "At least one input expected");
+    MLIRContext *context = (*inputTypes.begin()).getContext();
     AffineExpr {1};
     bindDims(context, {1});
     return SmallVector<AffineMap, 8>{{ {2} };
   })FMT";
 
+  // 2. Print a comma-separated list of identifiers for the AffineExpr in
+  // `state.dims`. These will replace the `{1}` placeholder in both
+  // `AffineExpr {1}` and `bindDims(context, {1})` ensuring the AffineExpr
+  // identifiers are bound in the right order to the proper AffineDimExpr.
   std::string dimsStr;
   llvm::raw_string_ostream ss(dimsStr);
   llvm::interleaveComma(
@@ -1552,10 +1576,14 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
       [&](std::pair<StringRef, AffineExpr> p) { ss << p.second; });
   ss.flush();
 
+  // 3. Print a comma-separated list of AffineMap constructors that use the
+  // identifiers from 1. The AffineExpr use the common arithmetic operators on
+  // AffineExpr. These AffineMap constructors will replace the `{2}` placeholder
+  // in return `SmallVector<AffineMap, 8>{{ {2} };`.
   std::string mapsStr;
   llvm::raw_string_ostream mapsStringStream(mapsStr);
   SmallVector<TensorUse, 4> orderedUses(state.orderedTensorArgs.size());
-  for (auto it : state.orderedTensorArgs)
+  for (const auto &it : state.orderedTensorArgs)
     orderedUses[it.second] = it.first;
   llvm::interleaveComma(orderedUses, mapsStringStream, [&](TensorUse u) {
     assert(u.indexingMap);
@@ -1576,6 +1604,7 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
   });
   mapsStringStream.flush();
 
+  // 4. Apply format to 1. using 2. and 3.
   os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr);
 }
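
These generated bodies are stitched together with llvm::formatv, which explains the doubled braces in the format strings above: {0}, {1}, ... are substituted arguments and "{{" escapes a literal '{' so the emitted C++ keeps its braces. A minimal standalone illustration (nothing here is taken from the generator verbatim):

```
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"

// Prints a small function skeleton in the same style as the generator's
// format strings: placeholders are substituted, "{{" becomes a literal '{'.
int main() {
  llvm::outs() << llvm::formatv(
      "SmallVector<StringRef, 8> {0}::referenceIterators(\n"
      "  TypeRange inputTypes, TypeRange outputTypes) {{\n"
      "  return SmallVector<StringRef, 8>{{ {1} };\n"
      "}\n",
      "Test1Op", "getParallelIteratorTypeName()");
  return 0;
}
```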
 


        

