[llvm-branch-commits] [mlir] fa8c397 - [mlir][Linalg] NFC: Refactor fusion of LinalgOp with TensorReshapeOp by expansion.

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Jan 8 12:02:54 PST 2021


Author: MaheshRavishankar
Date: 2021-01-08T11:58:19-08:00
New Revision: fa8c397dfa2ac2490236110a597d6aa764f41da4

URL: https://github.com/llvm/llvm-project/commit/fa8c397dfa2ac2490236110a597d6aa764f41da4
DIFF: https://github.com/llvm/llvm-project/commit/fa8c397dfa2ac2490236110a597d6aa764f41da4.diff

LOG: [mlir][Linalg] NFC: Refactor fusion of LinalgOp with TensorReshapeOp by expansion.

Change the implementation of LinalgOp with TensorReshapeOp by
expansion to be more modular and easier to follow.

Differential Revision: https://reviews.llvm.org/D93748

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index b1ac1a3b48b6..4075ddd12117 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -56,6 +56,7 @@ SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
 LoopRangeBuilder defaultLoopRangesBuilder(LinalgOp op);
 
 using ReassociationIndices = SmallVector<int64_t, 2>;
+using ReassociationIndicesRef = ArrayRef<int64_t>;
 using ReassociationExprs = SmallVector<AffineExpr, 2>;
 
/// Returns the name mangled library call name to disambiguate between different

diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index b1ea07309b4f..37062ac33e2b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -431,22 +431,160 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
          });
 }
 
-// Get the output tensor to use for the expanded operation. Creates an
-// `linalg.init_tensor` operation to materialize the tensor that carries the
-// shape information.
-static Value getOutputValueForExpansion(
-    OpBuilder &builder, Location loc, AffineMap outputIndexingMap, Value result,
-    ArrayRef<SmallVector<int64_t, 4>> origDimToExpandedShapeMap) {
+namespace {
+/// Information needed to expand a generic/indexed_generic operation to fold the
+/// reshape with it.
+class ExpansionInfo {
+public:
+  // Computes the mapping from original dimensions of the op to the dimensions
+  // of the expanded op given the `indexingMap` of the fused operand/result of
+  // the generic/indexed_generic op, the `reassociationMaps` of the reshape op
+  // and the shape of the expanded op.
+  LogicalResult compute(LinalgOp linalgOp, unsigned fusedTensorIndex,
+                        ArrayRef<AffineMap> reassociationMaps,
+                        ArrayRef<int64_t> expandedShape);
+  unsigned getOrigOpNumDims() const { return reassociation.size(); }
+  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
+  ReassociationIndicesRef getExpandedDims(unsigned i) const {
+    return reassociation[i];
+  }
+  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
+    return expandedShapeMap[i];
+  }
+
+private:
+  /// Reassociation from the dimensions in the original operation to the
+  /// dimension of the expanded operation.
+  SmallVector<ReassociationIndices, 4> reassociation;
+  /// Mapping from extent of loops in the original operation, to the extent of
+  /// loops in the expanded operation.
+  SmallVector<SmallVector<int64_t, 4>, 4> expandedShapeMap;
+  unsigned expandedOpNumDims;
+};
+} // namespace
+
+LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
+                                     unsigned fusedTensorIndex,
+                                     ArrayRef<AffineMap> reassociationMaps,
+                                     ArrayRef<int64_t> expandedShape) {
+  if (reassociationMaps.empty())
+    return failure();
+  AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
+
+  Optional<SmallVector<int64_t, 4>> originalLoopRange =
+      getStaticLoopRanges(linalgOp);
+  if (!originalLoopRange)
+    return linalgOp.emitError("unable to find loop range for operation");
+
+  reassociation.clear();
+  expandedShapeMap.clear();
+  // Compute the number of dimension in the expanded op that correspond to each
+  // dimension of the original op.
+  SmallVector<unsigned, 4> numExpandedDims(fusedIndexMap.getNumDims(), 1);
+  expandedShapeMap.resize(fusedIndexMap.getNumDims());
+  for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
+    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
+    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
+    numExpandedDims[pos] = foldedDims.getNumResults();
+    ArrayRef<int64_t> shape =
+        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
+    expandedShapeMap[pos].assign(shape.begin(), shape.end());
+  }
+  // The remaining dimensions remain the same.
+  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
+    if (expandedShapeMap[i].empty())
+      expandedShapeMap[i] = {(*originalLoopRange)[i]};
+
+  // Compute reassociation map from the original op to the expanded op.
+  unsigned sum = 0;
+  reassociation.reserve(fusedIndexMap.getNumDims());
+  for (auto numFoldedDim : llvm::enumerate(numExpandedDims)) {
+    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
+    reassociation.emplace_back(seq.begin(), seq.end());
+    sum += numFoldedDim.value();
+  }
+  expandedOpNumDims = sum;
+  return success();
+}
+
+/// To expand an indexed_generic operation, the body of the indexed generic op
+/// need to be modified appropriately. Specifically, uses of arguments for
+/// induction variables in the original operation need to be replaced with
+/// linearization of the corresponding arguments in the expanded op. That
+/// requires the shape of the expanded dimensions (at least all but the most
+/// significant). For now check that these are all statically sized. Note that
+/// this could be extended to handle dynamic case, but the implementation below
+/// uses `affine.apply` which seems to have issues when the shapes are not
+/// static.
+LogicalResult isIndexedGenericOpExpandable(LinalgOp linalgOp,
+                                           const ExpansionInfo &expansionInfo) {
+  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
+    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
+    if (expandedShape.size() == 1)
+      continue;
+    for (int64_t shape : expandedShape.drop_front()) {
+      if (ShapedType::isDynamic(shape)) {
+        return linalgOp.emitError(
+            "unable to fuse indexed generic op where the expanded dim is "
+            "dynamic");
+      }
+    }
+  }
+  return success();
+}
+
+/// Return the indexing map to use in the expanded op for a given the
+/// `indexingMap` of the original operation.
+static AffineMap
+getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
+                           const ExpansionInfo &expansionInfo) {
+  SmallVector<AffineExpr, 4> newExprs;
+  for (AffineExpr expr : indexingMap.getResults()) {
+    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
+        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
+          return builder.getAffineDimExpr(static_cast<unsigned>(v));
+        }));
+    newExprs.append(expandedExprs.begin(), expandedExprs.end());
+  }
+  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
+                        indexingMap.getNumSymbols(), newExprs,
+                        builder.getContext());
+}
+
+/// Return the type of the operand/result to use in the expanded op given the
+/// type in the original op.
+static RankedTensorType getExpandedType(RankedTensorType originalType,
+                                        AffineMap indexingMap,
+                                        const ExpansionInfo &expansionInfo) {
+  SmallVector<int64_t, 4> expandedShape;
+  for (AffineExpr expr : indexingMap.getResults()) {
+    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
+    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
+  }
+  return RankedTensorType::get(expandedShape, originalType.getElementType());
+}
+
+/// Get the value to use for the output in the expanded operation given the
+/// `indexingMap` for the output in the original op. Creates a
+/// `linalg.init_tensor` operation to materialize the tensor that carries the
+/// shape information. This is only used when the tensor_reshape is expanding
+/// and is a consumer. In such cases, the tensor_reshape op semantics guarantees
+/// that the shape of the output is computable from the shape of the input since
+/// at most one of the expanded dims can be dynamic.
+static Value getOutputValueForExpandedOp(OpBuilder &builder, Location loc,
+                                         AffineMap indexingMap, Value result,
+                                         const ExpansionInfo &expansionInfo) {
   SmallVector<Value, 4> dynamicDims;
   SmallVector<int64_t, 4> staticDims;
   ShapedType resultType = result.getType().cast<ShapedType>();
   ArrayRef<int64_t> origShape = resultType.getShape();
-  for (AffineExpr expr : outputIndexingMap.getResults()) {
+  for (AffineExpr expr : indexingMap.getResults()) {
     unsigned origDimPos = expr.cast<AffineDimExpr>().getPosition();
-    ArrayRef<int64_t> expandedShape(origDimToExpandedShapeMap[origDimPos]);
     bool foundDynamic = false;
     int64_t linearizedShape = 1;
-    for (int64_t extent : expandedShape) {
+    for (int64_t extent : expansionInfo.getExpandedShapeOfDim(origDimPos)) {
       if (ShapedType::isDynamic(extent)) {
         assert(!foundDynamic &&
                "Expanded dimensions of reshape can have only one dynamic dim");
@@ -467,6 +605,79 @@ static Value getOutputValueForExpansion(
                                               resultType.getElementType());
 }
 
+/// Returns the reassociation maps to use in the `linalg.tensor_reshape`
+/// operation to convert the operands of the original operation to operands of
+/// the expanded operation. The same method is used to compute the
+/// `linalg.tensor_reshape` used to collapse the result of the expanded op to
+/// get the value that can replace all uses of the results of the original op.
+static SmallVector<ReassociationIndices, 4>
+getReassociationForExpansion(AffineMap indexingMap,
+                             const ExpansionInfo &expansionInfo) {
+  SmallVector<ReassociationIndices, 4> reassociation;
+  unsigned numReshapeDims = 0;
+  for (AffineExpr expr : indexingMap.getResults()) {
+    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
+    auto indices = llvm::to_vector<2>(
+        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
+    reassociation.emplace_back(std::move(indices));
+    numReshapeDims += numExpandedDims;
+  }
+  return reassociation;
+}
+
+/// Build the body of the expanded IndexedGenericOp. The arguments for the
+/// induction variables of the original operation need to be recovered by
+/// linearizing the arguments of the corresponding dimensions of the expanded
+/// op. For now it is assumed that the shapes of the expanded op needed for
+/// linearization are static.
+static void buildExpandedIndexedGenericOpRegion(
+    PatternRewriter &rewriter, Location loc, Region &originalOpRegion,
+    Region &fusedOpRegion, const ExpansionInfo &expansionInfo) {
+  assert(fusedOpRegion.empty() && "expected fused op to have empty region");
+  // Create an entry block in the fused region with same number of arguments
+  // as the fused op
+  Block *fusedEntryBlock = new Block;
+  fusedOpRegion.push_back(fusedEntryBlock);
+  rewriter.cloneRegionBefore(originalOpRegion, fusedOpRegion,
+                             fusedOpRegion.end());
+
+  // Merge the entry block of the fused op with the cloned blocks. For this
+  // compute the value for arguments of the region in the original operation
+  // in terms of the arguments of the fused op. Since the original operation
+  // is expanded, the expanded dimensions need to be folded back to get the
+  // replacement value for the arguments corresponding to iteration index.
+  // For now this expects that all the loop ranges are constants, which is
+  // true if the shapes are all static. This has already been checked in the
+  // precondition.
+  using namespace edsc::op;
+  using namespace edsc::intrinsics;
+  OpBuilder::InsertionGuard guard(rewriter);
+  SmallVector<Value, 4> argReplacements(originalOpRegion.getNumArguments());
+  rewriter.setInsertionPointToStart(fusedEntryBlock);
+  edsc::ScopedContext scopedContext(rewriter, loc);
+  IndexType indexType = rewriter.getIndexType();
+  for (auto i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
+    Value linearizedIndex = fusedEntryBlock->addArgument(indexType);
+    ArrayRef<int64_t> expandedDimsShape =
+        expansionInfo.getExpandedShapeOfDim(i).drop_front();
+    for (unsigned shape : expandedDimsShape) {
+      assert(!ShapedType::isDynamic(shape));
+      linearizedIndex = linearizedIndex * std_constant_index(shape);
+      linearizedIndex =
+          linearizedIndex + fusedEntryBlock->addArgument(indexType);
+    }
+    argReplacements[i] = linearizedIndex;
+  }
+  for (auto i : llvm::seq<unsigned>(expansionInfo.getOrigOpNumDims(),
+                                    argReplacements.size())) {
+    argReplacements[i] =
+        fusedEntryBlock->addArgument(originalOpRegion.getArgument(i).getType());
+  }
+  rewriter.mergeBlocks(fusedEntryBlock->getNextNode(), fusedEntryBlock,
+                       argReplacements);
+}
+
 /// Implements the fusion of a tensor_reshape op and a generic/indexed_generic
 /// op as explained in `isFusableWithReshapeByExpansion`. Assumes that those
 /// conditions have been satisfied.
@@ -481,104 +692,22 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
       reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
   RankedTensorType expandedType =
       isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
-  AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
 
-  // The reshape is folding/expanding consecutive dimensions. Given the indexing
-  // map of the fused tensor find the number of dimensions each of the loops of
-  // the original op is expanded into. Also record the shape of the expanded
-  // dimensions.
-  ArrayRef<int64_t> expandedShape = expandedType.getShape();
-  Optional<SmallVector<int64_t, 4>> origOpLoopRange =
-      getStaticLoopRanges(linalgOp);
-  if (!origOpLoopRange) {
-    linalgOp.emitError("unable to find loop range for operation");
+  ExpansionInfo expansionInfo;
+  if (failed(expansionInfo.compute(linalgOp, fusedTensorIndex,
+                                   reshapeOp.getReassociationMaps(),
+                                   expandedType.getShape())))
     return llvm::None;
-  }
-  SmallVector<unsigned, 4> numFoldedDims(fusedIndexMap.getNumDims(), 1);
-  SmallVector<SmallVector<int64_t, 4>, 4> expandedDimsShape(
-      fusedIndexMap.getNumDims());
-  auto reassociationMaps = reshapeOp.getReassociationMaps();
-  for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
-    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
-    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
-    numFoldedDims[pos] = foldedDims.getNumResults();
-    ArrayRef<int64_t> shape =
-        expandedShape.slice(foldedDims.getDimPosition(0), numFoldedDims[pos]);
-    expandedDimsShape[pos].assign(shape.begin(), shape.end());
-  }
-  // The remaining dimensions remain the same.
-  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
-    if (expandedDimsShape[i].empty())
-      expandedDimsShape[i] = {(*origOpLoopRange)[i]};
-
-  if (isa<IndexedGenericOp>(linalgOp.getOperation())) {
-    // For indexed generic op, the region contains arguments that represent the
-    // induction variable value of the loops. In the fused op these values are
-    // obtained by linearizing the expanded dimensions. For now just check that
-    // the extents used in the linearization (all the expanded dims except the
-    // front) are statically know. For dynamic case, we would need shape
-    // information on these dimensions to get these.
-    for (auto &expandedShape : expandedDimsShape) {
-      if (expandedShape.size() == 1)
-        continue;
-      for (int64_t expandedDimShape : llvm::make_range(
-               std::next(expandedShape.begin()), expandedShape.end())) {
-        if (ShapedType::isDynamic(expandedDimShape)) {
-          linalgOp.emitError(
-              "unable to fuse indexed generic op where the expanded dim is "
-              "dynamic");
-          return llvm::None;
-        }
-      }
-    }
-  }
 
-  // The remapping of the indices is then the prefix sum (inclusive) of the
-  // numFoldedDims.
-  SmallVector<unsigned, 4> remapping(numFoldedDims.size() + 1, 0);
-  unsigned sum = 0;
-  for (auto numFoldedDim : llvm::enumerate(numFoldedDims)) {
-    sum += numFoldedDim.value();
-    remapping[numFoldedDim.index() + 1] = sum;
-  }
+  if (isa<IndexedGenericOp>(linalgOp.getOperation()) &&
+      failed(isIndexedGenericOpExpandable(linalgOp, expansionInfo)))
+    return llvm::None;
 
-  SmallVector<AffineMap, 4> expandedOpIndexingMaps;
-  // Compute the modified indexing maps by replacing every loop (AffineDimExpr)
-  // in the original indexing map with the sequence of loops that it is expanded
-  // to.
-  for (AffineMap indexingMap : linalgOp.getIndexingMaps()) {
-    SmallVector<AffineExpr, 4> newExprs;
-    for (AffineExpr expr : indexingMap.getResults()) {
-      unsigned pos = expr.cast<AffineDimExpr>().getPosition();
-      for (unsigned newPos :
-           llvm::seq<unsigned>(remapping[pos], remapping[pos + 1])) {
-        newExprs.push_back(rewriter.getAffineDimExpr(newPos));
-      }
-    }
-    expandedOpIndexingMaps.push_back(
-        AffineMap::get(remapping.back(), indexingMap.getNumSymbols(), newExprs,
-                       rewriter.getContext()));
-  }
+  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
+      llvm::map_range(linalgOp.getIndexingMaps(), [&](AffineMap m) {
+        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
+      }));
 
-  // The operands of the expanded op are computed by reshaping the original
-  // operands. The reshape depends on the ordering of the loop used to access
-  // the tensor in the original operation, and are expanded into as many
-  // dimensions as the loop is expanded into (as computed by `remapping`).
-  auto getReshapeInfo =
-      [&](AffineMap operandIndexingMap,
-          SmallVectorImpl<ReassociationIndices> &reassociation,
-          SmallVectorImpl<int64_t> &expandedOpOperandShape) {
-        unsigned reshapeDims = 0;
-        for (AffineExpr expr : operandIndexingMap.getResults()) {
-          unsigned origDim = expr.cast<AffineDimExpr>().getPosition();
-          auto foldedDims = llvm::seq<int64_t>(
-              reshapeDims, reshapeDims + numFoldedDims[origDim]);
-          reassociation.emplace_back(foldedDims.begin(), foldedDims.end());
-          expandedOpOperandShape.append(expandedDimsShape[origDim].begin(),
-                                        expandedDimsShape[origDim].end());
-          reshapeDims += numFoldedDims[origDim];
-        }
-      };
   SmallVector<Value, 4> expandedOpOperands;
   for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
     if (operand.index() == fusedTensorIndex) {
@@ -586,36 +715,31 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
       continue;
     }
     AffineMap indexingMap = linalgOp.getInputIndexingMap(operand.index());
-    SmallVector<ReassociationIndices, 4> reassociation;
-    SmallVector<int64_t, 4> expandedOperandShape;
-    getReshapeInfo(indexingMap, reassociation, expandedOperandShape);
-    Type expandedOperandType = RankedTensorType::get(
-        expandedOperandShape,
-        operand.value().getType().cast<ShapedType>().getElementType());
+    RankedTensorType expandedOperandType =
+        getExpandedType(operand.value().getType().cast<RankedTensorType>(),
+                        indexingMap, expansionInfo);
     if (expandedOperandType != operand.value().getType()) {
+      // Reshape the operand to get the right type.
+      SmallVector<ReassociationIndices, 4> reassociation =
+          getReassociationForExpansion(indexingMap, expansionInfo);
       expandedOpOperands.push_back(rewriter.create<TensorReshapeOp>(
           linalgOp.getLoc(), expandedOperandType, operand.value(),
           reassociation));
-    } else {
-      expandedOpOperands.push_back(operand.value());
+      continue;
     }
+    expandedOpOperands.push_back(operand.value());
   }
 
   Location loc = linalgOp.getLoc();
   SmallVector<Value, 1> outputs;
-  SmallVector<SmallVector<ReassociationIndices, 4>, 1> resultReassociation;
   for (auto result : llvm::enumerate(linalgOp.getOutputs())) {
     AffineMap indexingMap = linalgOp.getOutputIndexingMap(result.index());
-    SmallVector<ReassociationIndices, 4> reassociation;
-    SmallVector<int64_t, 4> expandedResultShape;
-    getReshapeInfo(indexingMap, reassociation, expandedResultShape);
-    outputs.push_back(getOutputValueForExpansion(
-        rewriter, loc, indexingMap, result.value(), expandedDimsShape));
-    resultReassociation.emplace_back(std::move(reassociation));
+    outputs.push_back(getOutputValueForExpandedOp(
+        rewriter, loc, indexingMap, result.value(), expansionInfo));
   }
 
   // The iterator types of the expanded op are all parallel.
-  SmallVector<StringRef, 4> iteratorTypes(remapping.back(),
+  SmallVector<StringRef, 4> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
                                           getParallelIteratorTypeName());
 
   TypeRange resultTypes = ValueRange(outputs).getTypes();
@@ -631,48 +755,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
                                fusedRegion.begin());
   } else {
     assert(isa<IndexedGenericOp>(linalgOp.getOperation()));
-    // Create an entry block in the fused Region with same number of arguments
-    // as the fused op
-    Block *fusedEntryBlock = new Block;
-    fusedRegion.push_back(fusedEntryBlock);
-    rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.end());
-
-    // Merge the entry block of the fused op with the cloned blocks. For this
-    // compute the value for arguments of the region in the original operation
-    // in terms of the arguments of the fused op. Since the original operation
-    // is expanded, the expanded dimensions need to be folded back to get the
-    // replacement value for the arguments corresponding to interation index.
-    // For now this expects that all the loop ranges are constants, which is
-    // true if the shapes are all static. This has already been checked in the
-    // precondition.
-    using namespace edsc::op;
-    using namespace edsc::intrinsics;
-    OpBuilder::InsertionGuard guard(rewriter);
-    SmallVector<Value, 4> argReplacements(originalRegion.getNumArguments());
-    rewriter.setInsertionPointToStart(fusedEntryBlock);
-    edsc::ScopedContext scopedContext(rewriter, fusedOp.getLoc());
-    IndexType indexType = rewriter.getIndexType();
-    for (unsigned i : llvm::seq<unsigned>(0, numFoldedDims.size())) {
-      Value linearizedIndex = fusedEntryBlock->addArgument(indexType);
-      for (unsigned foldedDim = remapping[i] + 1; foldedDim != remapping[i + 1];
-           foldedDim++) {
-        int64_t expandedDimExtent =
-            expandedDimsShape[i][foldedDim - remapping[i]];
-        assert(!ShapedType::isDynamic(expandedDimExtent));
-        linearizedIndex =
-            linearizedIndex * std_constant_index(expandedDimExtent);
-        linearizedIndex =
-            linearizedIndex + fusedEntryBlock->addArgument(indexType);
-      }
-      argReplacements[i] = linearizedIndex;
-    }
-    for (unsigned i :
-         llvm::seq<unsigned>(numFoldedDims.size(), argReplacements.size())) {
-      argReplacements[i] =
-          fusedEntryBlock->addArgument(originalRegion.getArgument(i).getType());
-    }
-    rewriter.mergeBlocks(fusedEntryBlock->getNextNode(), fusedEntryBlock,
-                         argReplacements);
+    buildExpandedIndexedGenericOpRegion(rewriter, loc, originalRegion,
+                                        fusedRegion, expansionInfo);
   }
 
   // Reshape the result values to their original shape if this is a collapsing
@@ -681,10 +765,12 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
   for (auto result : llvm::enumerate(linalgOp->getResults())) {
     if (!isExpanding &&
         resultTypes[result.index()] != result.value().getType()) {
+      SmallVector<ReassociationIndices, 4> reassociation =
+          getReassociationForExpansion(
+              linalgOp.getOutputIndexingMap(result.index()), expansionInfo);
       resultVals.push_back(rewriter.create<TensorReshapeOp>(
           linalgOp.getLoc(), result.value().getType(),
-          fusedOp->getResult(result.index()),
-          resultReassociation[result.index()]));
+          fusedOp->getResult(result.index()), reassociation));
     } else {
       resultVals.push_back(fusedOp->getResult(result.index()));
     }


        


More information about the llvm-branch-commits mailing list