[llvm-branch-commits] [mlir] fa8c397 - [mlir][Linalg] NFC: Refactor fusion of LinalgOp with TensorReshapeOp by expansion.
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jan 8 12:02:54 PST 2021
Author: MaheshRavishankar
Date: 2021-01-08T11:58:19-08:00
New Revision: fa8c397dfa2ac2490236110a597d6aa764f41da4
URL: https://github.com/llvm/llvm-project/commit/fa8c397dfa2ac2490236110a597d6aa764f41da4
DIFF: https://github.com/llvm/llvm-project/commit/fa8c397dfa2ac2490236110a597d6aa764f41da4.diff
LOG: [mlir][Linalg] NFC: Refactor fusion of LinalgOp with TensorReshapeOp by expansion.
Change the implementation of LinalgOp with TensorReshapeOp by
expansion to be more modular and easier to follow.
Differential Revision: https://reviews.llvm.org/D93748
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index b1ac1a3b48b6..4075ddd12117 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -56,6 +56,7 @@ SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
LoopRangeBuilder defaultLoopRangesBuilder(LinalgOp op);
using ReassociationIndices = SmallVector<int64_t, 2>;
+using ReassociationIndicesRef = ArrayRef<int64_t>;
using ReassociationExprs = SmallVector<AffineExpr, 2>;
/// Returns the name mangled library call name to disambiguate between different
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index b1ea07309b4f..37062ac33e2b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -431,22 +431,160 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
});
}
-// Get the output tensor to use for the expanded operation. Creates an
-// `linalg.init_tensor` operation to materialize the tensor that carries the
-// shape information.
-static Value getOutputValueForExpansion(
- OpBuilder &builder, Location loc, AffineMap outputIndexingMap, Value result,
- ArrayRef<SmallVector<int64_t, 4>> origDimToExpandedShapeMap) {
+namespace {
+/// Information needed to expand a generic/indexed_generic operation to fold the
+/// reshape with it.
+class ExpansionInfo {
+public:
+ // Computes the mapping from original dimensions of the op to the dimensions
+ // of the expanded op given the `indexingMap` of the fused operand/result of
+ // the generic/indexed_generic op, the `reassocationMaps` of the reshape op
+ // and the shape of the expanded op.
+ LogicalResult compute(LinalgOp linalgOp, unsigned fusedTensorIndex,
+ ArrayRef<AffineMap> reassociationMaps,
+ ArrayRef<int64_t> expandedShape);
+ unsigned getOrigOpNumDims() const { return reassociation.size(); }
+ unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
+ ReassociationIndicesRef getExpandedDims(unsigned i) const {
+ return reassociation[i];
+ }
+ ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
+ return expandedShapeMap[i];
+ }
+
+private:
+ /// Reassociation from the dimensions in the original operation to the
+ /// dimension of the expanded operation.
+ SmallVector<ReassociationIndices, 4> reassociation;
+ /// Mapping from extent of loops in the original operation, to the extent of
+ /// loops in the expanded operation.
+ SmallVector<SmallVector<int64_t, 4>, 4> expandedShapeMap;
+ unsigned expandedOpNumDims;
+};
+} // namespace
+
+LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
+ unsigned fusedTensorIndex,
+ ArrayRef<AffineMap> reassociationMaps,
+ ArrayRef<int64_t> expandedShape) {
+ if (reassociationMaps.empty())
+ return failure();
+ AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
+
+ Optional<SmallVector<int64_t, 4>> originalLoopRange =
+ getStaticLoopRanges(linalgOp);
+ if (!originalLoopRange)
+ return linalgOp.emitError("unable to find loop range for operation");
+
+ reassociation.clear();
+ expandedShapeMap.clear();
+ // Compute the number of dimension in the expanded op that correspond to each
+ // dimension of the original op.
+ SmallVector<unsigned, 4> numExpandedDims(fusedIndexMap.getNumDims(), 1);
+ expandedShapeMap.resize(fusedIndexMap.getNumDims());
+ for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
+ unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
+ AffineMap foldedDims = reassociationMaps[resultExpr.index()];
+ numExpandedDims[pos] = foldedDims.getNumResults();
+ ArrayRef<int64_t> shape =
+ expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
+ expandedShapeMap[pos].assign(shape.begin(), shape.end());
+ }
+ // The remaining dimensions remain the same.
+ for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
+ if (expandedShapeMap[i].empty())
+ expandedShapeMap[i] = {(*originalLoopRange)[i]};
+
+ // Compute reassociation map from the original op to the expanded op.
+ unsigned sum = 0;
+ reassociation.reserve(fusedIndexMap.getNumDims());
+ for (auto numFoldedDim : llvm::enumerate(numExpandedDims)) {
+ auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
+ reassociation.emplace_back(seq.begin(), seq.end());
+ sum += numFoldedDim.value();
+ }
+ expandedOpNumDims = sum;
+ return success();
+}
+
+/// To expand an indexed_generic operation, the body of the indexed generic op
+/// need to be modified appropriately. Specifically, uses of arguments for
+/// induction variables in the original operation need to be replaced with
+/// linearization of the corresponding arguments in the expanded op. That
+/// requires the shape of the expanded dimensions (at least all but the most
+/// significant). For now check that these are all statically sized. Note that
+/// this could be extended to handle dynamic case, but the implementation below
+/// uses `affine.apply` which seems to have issues when the shapes are not
+/// static.
+LogicalResult isIndexedGenericOpExpandable(LinalgOp linalgOp,
+ const ExpansionInfo &expansionInfo) {
+ for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
+ ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
+ if (expandedShape.size() == 1)
+ continue;
+ for (int64_t shape : expandedShape.drop_front()) {
+ if (ShapedType::isDynamic(shape)) {
+ return linalgOp.emitError(
+ "unable to fuse indexed generic op where the expanded dim is "
+ "dynamic");
+ }
+ }
+ }
+ return success();
+}
+
+/// Return the indexing map to use in the expanded op given the
+/// `indexingMap` of the original operation.
+static AffineMap
+getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
+ const ExpansionInfo &expansionInfo) {
+ SmallVector<AffineExpr, 4> newExprs;
+ for (AffineExpr expr : indexingMap.getResults()) {
+ unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+ SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
+ llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
+ return builder.getAffineDimExpr(static_cast<unsigned>(v));
+ }));
+ newExprs.append(expandedExprs.begin(), expandedExprs.end());
+ }
+ return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
+ indexingMap.getNumSymbols(), newExprs,
+ builder.getContext());
+}
+
+/// Return the type of the operand/result to use in the expanded op given the
+/// type in the original op.
+static RankedTensorType getExpandedType(RankedTensorType originalType,
+ AffineMap indexingMap,
+ const ExpansionInfo &expansionInfo) {
+ SmallVector<int64_t, 4> expandedShape;
+ for (AffineExpr expr : indexingMap.getResults()) {
+ unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+ auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
+ expandedShape.append(dimExpansion.begin(), dimExpansion.end());
+ }
+ return RankedTensorType::get(expandedShape, originalType.getElementType());
+}
+
+/// Get the value to use for the output in the expanded operation given the
+/// `indexingMap` for the output in the original op. Creates a
+/// `linalg.init_tensor` operation to materialize the tensor that carries the
+/// shape information. This is only used when the tensor_reshape is expanding
+/// and is a consumer. In such cases, the tensor_reshape op semantics guarantees
+/// that the shape of the output is computable from the shape of the input since
+/// at most one of the expanded dims can be dynamic.
+static Value getOutputValueForExpandedOp(OpBuilder &builder, Location loc,
+ AffineMap indexingMap, Value result,
+ const ExpansionInfo &expansionInfo) {
SmallVector<Value, 4> dynamicDims;
SmallVector<int64_t, 4> staticDims;
ShapedType resultType = result.getType().cast<ShapedType>();
ArrayRef<int64_t> origShape = resultType.getShape();
- for (AffineExpr expr : outputIndexingMap.getResults()) {
+ for (AffineExpr expr : indexingMap.getResults()) {
unsigned origDimPos = expr.cast<AffineDimExpr>().getPosition();
- ArrayRef<int64_t> expandedShape(origDimToExpandedShapeMap[origDimPos]);
bool foundDynamic = false;
int64_t linearizedShape = 1;
- for (int64_t extent : expandedShape) {
+ for (int64_t extent : expansionInfo.getExpandedShapeOfDim(origDimPos)) {
if (ShapedType::isDynamic(extent)) {
assert(!foundDynamic &&
"Expanded dimensions of reshape can have only one dynamic dim");
@@ -467,6 +605,79 @@ static Value getOutputValueForExpansion(
resultType.getElementType());
}
+/// Returns the reassociation maps to use in the `linalg.tensor_reshape`
+/// operation to convert the operands of the original operation to operands of
+/// the expanded operation. The same method is used to compute the
+/// `linalg.tensor_reshape` used to collapse the result of the expanded op to
+/// get the value that can replace all uses of the results of the original op.
+static SmallVector<ReassociationIndices, 4>
+getReassociationForExpansion(AffineMap indexingMap,
+ const ExpansionInfo &expansionInfo) {
+ SmallVector<ReassociationIndices, 4> reassociation;
+ unsigned numReshapeDims = 0;
+ for (AffineExpr expr : indexingMap.getResults()) {
+ unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+ auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
+ auto indices = llvm::to_vector<2>(
+ llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
+ reassociation.emplace_back(std::move(indices));
+ numReshapeDims += numExpandedDims;
+ }
+ return reassociation;
+}
+
+/// Build the body of the expanded IndexedGenericOp. The arguments for the
+/// induction variables of the original operation need to be recovered by
+/// linearizing the arguments of the corresponding dimensions of the expanded
+/// op. For now it is assumed that the shapes of the expanded op needed for
+/// linearization are static.
+static void buildExpandedIndexedGenericOpRegion(
+ PatternRewriter &rewriter, Location loc, Region &originalOpRegion,
+ Region &fusedOpRegion, const ExpansionInfo &expansionInfo) {
+ assert(fusedOpRegion.empty() && "expected fused op to have empty region");
+ // Create an entry block in the fused region with same number of arguments
+ // as the fused op
+ Block *fusedEntryBlock = new Block;
+ fusedOpRegion.push_back(fusedEntryBlock);
+ rewriter.cloneRegionBefore(originalOpRegion, fusedOpRegion,
+ fusedOpRegion.end());
+
+ // Merge the entry block of the fused op with the cloned blocks. For this
+ // compute the value for arguments of the region in the original operation
+ // in terms of the arguments of the fused op. Since the original operation
+ // is expanded, the expanded dimensions need to be folded back to get the
+  // replacement value for the arguments corresponding to iteration index.
+ // For now this expects that all the loop ranges are constants, which is
+ // true if the shapes are all static. This has already been checked in the
+ // precondition.
+ using namespace edsc::op;
+ using namespace edsc::intrinsics;
+ OpBuilder::InsertionGuard guard(rewriter);
+ SmallVector<Value, 4> argReplacements(originalOpRegion.getNumArguments());
+ rewriter.setInsertionPointToStart(fusedEntryBlock);
+ edsc::ScopedContext scopedContext(rewriter, loc);
+ IndexType indexType = rewriter.getIndexType();
+ for (auto i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
+ Value linearizedIndex = fusedEntryBlock->addArgument(indexType);
+ ArrayRef<int64_t> expandedDimsShape =
+ expansionInfo.getExpandedShapeOfDim(i).drop_front();
+ for (unsigned shape : expandedDimsShape) {
+ assert(!ShapedType::isDynamic(shape));
+ linearizedIndex = linearizedIndex * std_constant_index(shape);
+ linearizedIndex =
+ linearizedIndex + fusedEntryBlock->addArgument(indexType);
+ }
+ argReplacements[i] = linearizedIndex;
+ }
+ for (auto i : llvm::seq<unsigned>(expansionInfo.getOrigOpNumDims(),
+ argReplacements.size())) {
+ argReplacements[i] =
+ fusedEntryBlock->addArgument(originalOpRegion.getArgument(i).getType());
+ }
+ rewriter.mergeBlocks(fusedEntryBlock->getNextNode(), fusedEntryBlock,
+ argReplacements);
+}
+
/// Implements the fusion of a tensor_reshape op and a generic/indexed_generic
/// op as explained in `isFusableWithReshapeByExpansion`. Assumes that those
/// conditions have been satisfied.
@@ -481,104 +692,22 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
RankedTensorType expandedType =
isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
- AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
- // The reshape is folding/expanding consecutive dimensions. Given the indexing
- // map of the fused tensor find the number of dimensions each of the loops of
- // the original op is expanded into. Also record the shape of the expanded
- // dimensions.
- ArrayRef<int64_t> expandedShape = expandedType.getShape();
- Optional<SmallVector<int64_t, 4>> origOpLoopRange =
- getStaticLoopRanges(linalgOp);
- if (!origOpLoopRange) {
- linalgOp.emitError("unable to find loop range for operation");
+ ExpansionInfo expansionInfo;
+ if (failed(expansionInfo.compute(linalgOp, fusedTensorIndex,
+ reshapeOp.getReassociationMaps(),
+ expandedType.getShape())))
return llvm::None;
- }
- SmallVector<unsigned, 4> numFoldedDims(fusedIndexMap.getNumDims(), 1);
- SmallVector<SmallVector<int64_t, 4>, 4> expandedDimsShape(
- fusedIndexMap.getNumDims());
- auto reassociationMaps = reshapeOp.getReassociationMaps();
- for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
- unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
- AffineMap foldedDims = reassociationMaps[resultExpr.index()];
- numFoldedDims[pos] = foldedDims.getNumResults();
- ArrayRef<int64_t> shape =
- expandedShape.slice(foldedDims.getDimPosition(0), numFoldedDims[pos]);
- expandedDimsShape[pos].assign(shape.begin(), shape.end());
- }
- // The remaining dimensions remain the same.
- for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
- if (expandedDimsShape[i].empty())
- expandedDimsShape[i] = {(*origOpLoopRange)[i]};
-
- if (isa<IndexedGenericOp>(linalgOp.getOperation())) {
- // For indexed generic op, the region contains arguments that represent the
- // induction variable value of the loops. In the fused op these values are
- // obtained by linearizing the expanded dimensions. For now just check that
- // the extents used in the linearization (all the expanded dims except the
- // front) are statically know. For dynamic case, we would need shape
- // information on these dimensions to get these.
- for (auto &expandedShape : expandedDimsShape) {
- if (expandedShape.size() == 1)
- continue;
- for (int64_t expandedDimShape : llvm::make_range(
- std::next(expandedShape.begin()), expandedShape.end())) {
- if (ShapedType::isDynamic(expandedDimShape)) {
- linalgOp.emitError(
- "unable to fuse indexed generic op where the expanded dim is "
- "dynamic");
- return llvm::None;
- }
- }
- }
- }
- // The remapping of the indices is then the prefix sum (inclusive) of the
- // numFoldedDims.
- SmallVector<unsigned, 4> remapping(numFoldedDims.size() + 1, 0);
- unsigned sum = 0;
- for (auto numFoldedDim : llvm::enumerate(numFoldedDims)) {
- sum += numFoldedDim.value();
- remapping[numFoldedDim.index() + 1] = sum;
- }
+ if (isa<IndexedGenericOp>(linalgOp.getOperation()) &&
+ failed(isIndexedGenericOpExpandable(linalgOp, expansionInfo)))
+ return llvm::None;
- SmallVector<AffineMap, 4> expandedOpIndexingMaps;
- // Compute the modified indexing maps by replacing every loop (AffineDimExpr)
- // in the original indexing map with the sequence of loops that it is expanded
- // to.
- for (AffineMap indexingMap : linalgOp.getIndexingMaps()) {
- SmallVector<AffineExpr, 4> newExprs;
- for (AffineExpr expr : indexingMap.getResults()) {
- unsigned pos = expr.cast<AffineDimExpr>().getPosition();
- for (unsigned newPos :
- llvm::seq<unsigned>(remapping[pos], remapping[pos + 1])) {
- newExprs.push_back(rewriter.getAffineDimExpr(newPos));
- }
- }
- expandedOpIndexingMaps.push_back(
- AffineMap::get(remapping.back(), indexingMap.getNumSymbols(), newExprs,
- rewriter.getContext()));
- }
+ SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
+ llvm::map_range(linalgOp.getIndexingMaps(), [&](AffineMap m) {
+ return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
+ }));
- // The operands of the expanded op are computed by reshaping the original
- // operands. The reshape depends on the ordering of the loop used to access
- // the tensor in the original operation, and are expanded into as many
- // dimensions as the loop is expanded into (as computed by `remapping`).
- auto getReshapeInfo =
- [&](AffineMap operandIndexingMap,
- SmallVectorImpl<ReassociationIndices> &reassociation,
- SmallVectorImpl<int64_t> &expandedOpOperandShape) {
- unsigned reshapeDims = 0;
- for (AffineExpr expr : operandIndexingMap.getResults()) {
- unsigned origDim = expr.cast<AffineDimExpr>().getPosition();
- auto foldedDims = llvm::seq<int64_t>(
- reshapeDims, reshapeDims + numFoldedDims[origDim]);
- reassociation.emplace_back(foldedDims.begin(), foldedDims.end());
- expandedOpOperandShape.append(expandedDimsShape[origDim].begin(),
- expandedDimsShape[origDim].end());
- reshapeDims += numFoldedDims[origDim];
- }
- };
SmallVector<Value, 4> expandedOpOperands;
for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
if (operand.index() == fusedTensorIndex) {
@@ -586,36 +715,31 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
continue;
}
AffineMap indexingMap = linalgOp.getInputIndexingMap(operand.index());
- SmallVector<ReassociationIndices, 4> reassociation;
- SmallVector<int64_t, 4> expandedOperandShape;
- getReshapeInfo(indexingMap, reassociation, expandedOperandShape);
- Type expandedOperandType = RankedTensorType::get(
- expandedOperandShape,
- operand.value().getType().cast<ShapedType>().getElementType());
+ RankedTensorType expandedOperandType =
+ getExpandedType(operand.value().getType().cast<RankedTensorType>(),
+ indexingMap, expansionInfo);
if (expandedOperandType != operand.value().getType()) {
+ // Reshape the operand to get the right type.
+ SmallVector<ReassociationIndices, 4> reassociation =
+ getReassociationForExpansion(indexingMap, expansionInfo);
expandedOpOperands.push_back(rewriter.create<TensorReshapeOp>(
linalgOp.getLoc(), expandedOperandType, operand.value(),
reassociation));
- } else {
- expandedOpOperands.push_back(operand.value());
+ continue;
}
+ expandedOpOperands.push_back(operand.value());
}
Location loc = linalgOp.getLoc();
SmallVector<Value, 1> outputs;
- SmallVector<SmallVector<ReassociationIndices, 4>, 1> resultReassociation;
for (auto result : llvm::enumerate(linalgOp.getOutputs())) {
AffineMap indexingMap = linalgOp.getOutputIndexingMap(result.index());
- SmallVector<ReassociationIndices, 4> reassociation;
- SmallVector<int64_t, 4> expandedResultShape;
- getReshapeInfo(indexingMap, reassociation, expandedResultShape);
- outputs.push_back(getOutputValueForExpansion(
- rewriter, loc, indexingMap, result.value(), expandedDimsShape));
- resultReassociation.emplace_back(std::move(reassociation));
+ outputs.push_back(getOutputValueForExpandedOp(
+ rewriter, loc, indexingMap, result.value(), expansionInfo));
}
// The iterator types of the expanded op are all parallel.
- SmallVector<StringRef, 4> iteratorTypes(remapping.back(),
+ SmallVector<StringRef, 4> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
getParallelIteratorTypeName());
TypeRange resultTypes = ValueRange(outputs).getTypes();
@@ -631,48 +755,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
fusedRegion.begin());
} else {
assert(isa<IndexedGenericOp>(linalgOp.getOperation()));
- // Create an entry block in the fused Region with same number of arguments
- // as the fused op
- Block *fusedEntryBlock = new Block;
- fusedRegion.push_back(fusedEntryBlock);
- rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.end());
-
- // Merge the entry block of the fused op with the cloned blocks. For this
- // compute the value for arguments of the region in the original operation
- // in terms of the arguments of the fused op. Since the original operation
- // is expanded, the expanded dimensions need to be folded back to get the
- // replacement value for the arguments corresponding to interation index.
- // For now this expects that all the loop ranges are constants, which is
- // true if the shapes are all static. This has already been checked in the
- // precondition.
- using namespace edsc::op;
- using namespace edsc::intrinsics;
- OpBuilder::InsertionGuard guard(rewriter);
- SmallVector<Value, 4> argReplacements(originalRegion.getNumArguments());
- rewriter.setInsertionPointToStart(fusedEntryBlock);
- edsc::ScopedContext scopedContext(rewriter, fusedOp.getLoc());
- IndexType indexType = rewriter.getIndexType();
- for (unsigned i : llvm::seq<unsigned>(0, numFoldedDims.size())) {
- Value linearizedIndex = fusedEntryBlock->addArgument(indexType);
- for (unsigned foldedDim = remapping[i] + 1; foldedDim != remapping[i + 1];
- foldedDim++) {
- int64_t expandedDimExtent =
- expandedDimsShape[i][foldedDim - remapping[i]];
- assert(!ShapedType::isDynamic(expandedDimExtent));
- linearizedIndex =
- linearizedIndex * std_constant_index(expandedDimExtent);
- linearizedIndex =
- linearizedIndex + fusedEntryBlock->addArgument(indexType);
- }
- argReplacements[i] = linearizedIndex;
- }
- for (unsigned i :
- llvm::seq<unsigned>(numFoldedDims.size(), argReplacements.size())) {
- argReplacements[i] =
- fusedEntryBlock->addArgument(originalRegion.getArgument(i).getType());
- }
- rewriter.mergeBlocks(fusedEntryBlock->getNextNode(), fusedEntryBlock,
- argReplacements);
+ buildExpandedIndexedGenericOpRegion(rewriter, loc, originalRegion,
+ fusedRegion, expansionInfo);
}
// Reshape the result values to their original shape if this is a collapsing
@@ -681,10 +765,12 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
for (auto result : llvm::enumerate(linalgOp->getResults())) {
if (!isExpanding &&
resultTypes[result.index()] != result.value().getType()) {
+ SmallVector<ReassociationIndices, 4> reassociation =
+ getReassociationForExpansion(
+ linalgOp.getOutputIndexingMap(result.index()), expansionInfo);
resultVals.push_back(rewriter.create<TensorReshapeOp>(
linalgOp.getLoc(), result.value().getType(),
- fusedOp->getResult(result.index()),
- resultReassociation[result.index()]));
+ fusedOp->getResult(result.index()), reassociation));
} else {
resultVals.push_back(fusedOp->getResult(result.index()));
}
More information about the llvm-branch-commits
mailing list