[Mlir-commits] [mlir] 1227b8a - [mlir] Rename getTied* methods to getMatching* in LinalgInterface.
Oleg Shyshkov
llvmlistbot at llvm.org
Fri Sep 30 03:06:06 PDT 2022
Author: Oleg Shyshkov
Date: 2022-09-30T10:05:45Z
New Revision: 1227b8ab54f902f3016320f75b8dde5b2d540c31
URL: https://github.com/llvm/llvm-project/commit/1227b8ab54f902f3016320f75b8dde5b2d540c31
DIFF: https://github.com/llvm/llvm-project/commit/1227b8ab54f902f3016320f75b8dde5b2d540c31.diff
LOG: [mlir] Rename getTied* methods to getMatching* in LinalgInterface.
Summary:
As noted in a review comment on https://reviews.llvm.org/D134444, the term `tied`
is a misnomer in this context; `matching` describes the relationship more accurately.
Differential Revision: https://reviews.llvm.org/D134534
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
index 350b41ac62535..e49fe0c54add7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
+++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
@@ -82,8 +82,8 @@ class LinalgDependenceGraph {
if (!owner)
return llvm::None;
if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
- return owner.getTiedIndexingMap(operand);
- return owner.getTiedIndexingMap(owner.getOutputOperand(
+ return owner.getMatchingIndexingMap(operand);
+ return owner.getMatchingIndexingMap(owner.getOutputOperand(
opView.get<Value>().cast<OpResult>().getResultNumber()));
}
// Return the operand number if the `opView` is an OpOperand *. Otherwise
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 299f931b77198..b72d0944ded67 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -377,7 +377,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
Return the block argument for an `opOperand`.
}],
/*retTy=*/"BlockArgument",
- /*methodName=*/"getTiedBlockArgument",
+ /*methodName=*/"getMatchingBlockArgument",
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -390,7 +390,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
Return the operand for a `blockArgument`.
}],
/*retTy=*/"OpOperand *",
- /*methodName=*/"getTiedOpOperand",
+ /*methodName=*/"getMatchingOpOperand",
/*args=*/(ins "BlockArgument":$blockArgument),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -404,7 +404,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
Return the input or output indexing map for `opOperand`.
}],
/*retTy=*/"AffineMap",
- /*methodName=*/"getTiedIndexingMap",
+ /*methodName=*/"getMatchingIndexingMap",
/*args=*/(ins "OpOperand*":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -419,7 +419,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
Return the indexing map for a `result`.
}],
/*retTy=*/"AffineMap",
- /*methodName=*/"getTiedIndexingMapForResult",
+ /*methodName=*/"getIndexingMapMatchingResult",
/*args=*/(ins "OpResult":$result),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -442,7 +442,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
`opOperand`.
}],
/*retTy=*/"OpOperand *",
- /*methodName=*/"getTiedYieldValue",
+ /*methodName=*/"getMatchingYieldValue",
/*args=*/(ins "OpOperand*":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index b8a93c0408f25..8e7c797e34456 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -34,7 +34,7 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
for (auto *opOperand : linalgOp.getInputAndOutputOperands()) {
if (llvm::is_contained(droppedOperands, opOperand))
continue;
- indexingMaps.push_back(linalgOp.getTiedIndexingMap(opOperand));
+ indexingMaps.push_back(linalgOp.getMatchingIndexingMap(opOperand));
}
return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
}
@@ -658,7 +658,7 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
<< linalgOp.getNumInputsAndOutputs() << ")";
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
- AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
+ AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
// Symbols disallowed.
if (indexingMap.getNumSymbols() != 0)
@@ -696,7 +696,7 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
for (int64_t &range : endLoopRangeValues)
range -= 1;
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
- AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
+ AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
SmallVector<int64_t, 4> startIndices =
indexingMap.compose(startLoopRangeValues);
SmallVector<int64_t, 4> endIndices =
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b9bf824808e8a..256e61ee03547 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -945,7 +945,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
// Check if this operand is a duplicate.
AffineMap indexingMap =
- genericOp.getTiedIndexingMap(inputOpOperand.value());
+ genericOp.getMatchingIndexingMap(inputOpOperand.value());
auto it = dedupedInputs.find(
std::make_pair(inputOpOperand.value()->get(), indexingMap));
if (it != dedupedInputs.end()) {
@@ -984,7 +984,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
newOutputOperands.push_back(outputOpOperand.value()->get());
newIndexingMaps.push_back(
- genericOp.getTiedIndexingMap(outputOpOperand.value()));
+ genericOp.getMatchingIndexingMap(outputOpOperand.value()));
}
} else {
// Output argument can be dropped if the result has
@@ -997,7 +997,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
llvm::enumerate(genericOp.getOutputOperands())) {
Value result = genericOp.getResult(outputOpOperand.index());
AffineMap indexingMap =
- genericOp.getTiedIndexingMap(outputOpOperand.value());
+ genericOp.getMatchingIndexingMap(outputOpOperand.value());
auto key =
std::make_tuple(outputOpOperand.value()->get(), indexingMap,
yieldOp->getOperand(outputOpOperand.index()));
@@ -1033,7 +1033,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
dedupedOutpts[key] = newOutputOperands.size();
newOutputOperands.push_back(outputOpOperand.value()->get());
newIndexingMaps.push_back(
- genericOp.getTiedIndexingMap(outputOpOperand.value()));
+ genericOp.getMatchingIndexingMap(outputOpOperand.value()));
}
}
@@ -1957,7 +1957,7 @@ static void populateMap(LinalgOp linalgOp, ArrayRef<OpOperand *> operands,
continue;
Value src = opOperand->get();
auto sourceType = src.getType().cast<RankedTensorType>();
- auto sourceMap = linalgOp.getTiedIndexingMap(opOperand);
+ auto sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
// Get the `sourceShape` of the `sourceType`. If the operand is a result of
// `tensor.cast` operation and source of the cast operation has a static
@@ -2005,7 +2005,7 @@ static void createNewOperandWithStaticSizes(
return;
}
ArrayRef<int64_t> sourceShape = sourceType.getShape();
- AffineMap sourceMap = linalgOp.getTiedIndexingMap(opOperand);
+ AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
SmallVector<int64_t> newShape;
// If operand is updated with new shape, `newOperandNeeded` will be
// true.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
index 60156d9bfeb3b..2cf8a57f3fc83 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
@@ -81,7 +81,7 @@ struct BubbleUpExtractSliceOpPattern
}
OpOperand *outOperand = linalgOp.getOutputOperand(0);
- AffineMap indexingMap = linalgOp.getTiedIndexingMap(outOperand);
+ AffineMap indexingMap = linalgOp.getMatchingIndexingMap(outOperand);
if (!indexingMap.isProjectedPermutation()) {
return rewriter.notifyMatchFailure(
sliceOp, "expected a projected permutation for output");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
index 192927f14a681..6bdb33b831ae1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -180,7 +180,7 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
OpResult result = genericOp.getResult(*resultNumber).cast<OpResult>();
newResultTypes.push_back(result.getType());
peeledGenericOpIndexingMaps.push_back(
- genericOp.getTiedIndexingMapForResult(result));
+ genericOp.getIndexingMapMatchingResult(result));
continue;
}
@@ -227,15 +227,16 @@ DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
/// as those used for the new results of the peeledGenericOp.
auto indexingMaps = llvm::to_vector(
llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *operand) {
- return genericOp.getTiedIndexingMap(operand);
+ return genericOp.getMatchingIndexingMap(operand);
}));
for (auto resultNum :
llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
OpResult result = peeledGenericOp.getResult(resultNum).cast<OpResult>();
- indexingMaps.push_back(peeledGenericOp.getTiedIndexingMapForResult(result));
+ indexingMaps.push_back(
+ peeledGenericOp.getIndexingMapMatchingResult(result));
}
for (OpOperand *outOperand : genericOp.getOutputOperands())
- indexingMaps.push_back(genericOp.getTiedIndexingMap(outOperand));
+ indexingMaps.push_back(genericOp.getMatchingIndexingMap(outOperand));
auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
return rewriter.create<GenericOp>(
@@ -263,7 +264,7 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
}
if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
- return !genericOp.getTiedIndexingMap(outOperand).isPermutation();
+ return !genericOp.getMatchingIndexingMap(outOperand).isPermutation();
})) {
return rewriter.notifyMatchFailure(
genericOp, "unhandled decomposition of generic op with out operand not "
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index f09286de97ea5..9a5614a8fad1a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -245,7 +245,7 @@ struct UnitExtentReplacementInfo {
static llvm::Optional<UnitExtentReplacementInfo>
replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
MLIRContext *context) {
- AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+ AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
ArrayRef<AffineExpr> exprs = indexingMap.getResults();
SmallVector<AffineExpr> reassociations;
@@ -390,7 +390,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
// type, indexing map, and create a set of mappings representing an
// identity matrix.
newInputOutputTypes.push_back(opOperand->get().getType());
- newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
+ newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(opOperand));
int64_t origRank = genericOp.getRank(opOperand);
auto maps = llvm::to_vector<8>(llvm::map_range(
llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f114b73315fb9..eb2f2f1c3cee8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -59,7 +59,7 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
// argMap is a map from producer loop -> producer arg tensor index.
- AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
+ AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);
// Compose argMap with invProducerResultIndexMap to get a map from
// producer result tensor index -> producer arg tensor index.
@@ -95,14 +95,14 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
// Get the consumer index map. The number of results of the consumer index
// map must match the number of loops of the producer.
- AffineMap consumerIndexMap = consumer.getTiedIndexingMap(fusedOperand);
+ AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
if (consumerIndexMap.getNumResults() != producer.getNumLoops())
return false;
// Finally the index_map for the result must be invertible. For now just
// verify it is a permutation.
AffineMap producerResultIndexMap =
- producer.getTiedIndexingMap(producer.getOutputOperand(0));
+ producer.getMatchingIndexingMap(producer.getOutputOperand(0));
if (!producerResultIndexMap.isPermutation())
return false;
@@ -288,17 +288,17 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
assert(it != consumerInputs.end() && "expected to find the consumer operand");
for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
fusedInputOperands.push_back(opOperand->get());
- fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
+ fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
}
// 4. Splice in producer's input operands/maps.
AffineMap producerResultIndexMap =
- producer.getTiedIndexingMapForResult(producerResult);
+ producer.getIndexingMapMatchingResult(producerResult);
for (OpOperand *opOperand : producer.getInputOperands()) {
fusedInputOperands.push_back(opOperand->get());
// Compute indexing maps for the producer args in the fused operation.
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
opOperand, producerResultIndexMap,
- consumer.getTiedIndexingMap(fusedOperand));
+ consumer.getMatchingIndexingMap(fusedOperand));
fusedIndexMaps.push_back(map);
}
// 5. Remaining consumer's input operands/maps (drop past index
@@ -306,7 +306,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
for (OpOperand *opOperand :
llvm::make_range(std::next(it), consumerInputs.end())) {
fusedInputOperands.push_back(opOperand->get());
- fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
+ fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
}
// 6. Collect all of the producer outputs.
@@ -314,7 +314,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
fusedOutputOperands.push_back(opOperand->get());
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
opOperand, producerResultIndexMap,
- consumer.getTiedIndexingMap(fusedOperand));
+ consumer.getMatchingIndexingMap(fusedOperand));
fusedIndexMaps.push_back(map);
fusedResultTypes.push_back(opOperand->get().getType());
}
@@ -322,7 +322,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
// 7. All of consumer's output operands (skip operands: added by the builder).
for (OpOperand *opOperand : consumer.getOutputOperands()) {
fusedOutputOperands.push_back(opOperand->get());
- fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
+ fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
fusedResultTypes.push_back(opOperand->get().getType());
}
@@ -344,7 +344,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
// Construct an AffineMap from consumer loops to producer loops.
// consumer loop -> tensor index
- AffineMap consumerResultIndexMap = consumer.getTiedIndexingMap(fusedOperand);
+ AffineMap consumerResultIndexMap =
+ consumer.getMatchingIndexingMap(fusedOperand);
// tensor index -> producer loop
AffineMap invProducerResultIndexMap =
inversePermutation(producerResultIndexMap);
@@ -466,7 +467,7 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
.getValue()
.isProjectedPermutation();
}) &&
- genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
+ genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() > 0 &&
llvm::all_of(genericOp.getIteratorTypesArray(), [](StringRef it) {
return it == getParallelIteratorTypeName();
});
@@ -517,7 +518,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
PatternRewriter &rewriter) {
if (reassociationMaps.empty())
return failure();
- AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
+ AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
@@ -727,7 +728,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
continue;
}
if (genericOp.isInputTensor(opOperand)) {
- AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+ AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
RankedTensorType expandedOperandType =
getExpandedType(opOperandType, indexingMap, expansionInfo);
@@ -755,7 +756,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
Location loc = genericOp.getLoc();
SmallVector<Value> outputs;
for (OpOperand *opOperand : genericOp.getOutputOperands()) {
- AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+ AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
RankedTensorType expandedOutputType =
getExpandedType(opOperandType, indexingMap, expansionInfo);
@@ -802,7 +803,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
if (resultTypes[resultNumber] != opResult.getType()) {
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(
- genericOp.getTiedIndexingMap(
+ genericOp.getMatchingIndexingMap(
genericOp.getOutputOperand(resultNumber)),
expansionInfo);
resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
@@ -1063,7 +1064,7 @@ getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
}
llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
- AffineMap indexingMap = genericOp.getTiedIndexingMap(fusableOperand);
+ AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
auto iteratorTypes = genericOp.getIteratorTypes().getValue();
SmallVector<ReassociationIndices> iterationSpaceReassociation;
for (ReassociationIndicesRef foldedRangeDims : reassociation) {
@@ -1312,7 +1313,7 @@ static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
OpOperand *opOperand,
const CollapsingInfo &collapsingInfo,
OpBuilder &builder) {
- AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+ AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
SmallVector<ReassociationIndices> operandReassociation =
getOperandReassociation(indexingMap, collapsingInfo);
@@ -1470,7 +1471,7 @@ static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
AffineMap indexingMap =
- genericOp.getTiedIndexingMapForResult(originalResult.value());
+ genericOp.getIndexingMapMatchingResult(originalResult.value());
SmallVector<ReassociationIndices> reassociation =
getOperandReassociation(indexingMap, collapsingInfo);
Value result = rewriter.create<tensor::ExpandShapeOp>(
@@ -1594,12 +1595,14 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
if (inputOperand == opOperand)
continue;
Value inputValue = inputOperand->get();
- fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
+ fusedIndexMaps.push_back(
+ genericOp.getMatchingIndexingMap(inputOperand));
fusedOperands.push_back(inputValue);
fusedLocs.push_back(inputValue.getLoc());
}
for (OpOperand *outputOperand : genericOp.getOutputOperands())
- fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));
+ fusedIndexMaps.push_back(
+ genericOp.getMatchingIndexingMap(outputOperand));
// Check if the operation shapes to loops map is computable.
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 685ad1e5aa20e..5738d51373493 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -80,7 +80,7 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
opOperand->get().getDefiningOp()))
continue;
- AffineMap map = op.getTiedIndexingMap(opOperand);
+ AffineMap map = op.getMatchingIndexingMap(opOperand);
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
<< opOperand->getOperandNumber() << "\n");
LLVM_DEBUG(llvm::dbgs()
@@ -442,7 +442,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
OpOperand *opOperand =
producerOp.getOutputOperand(producerOpResult.getResultNumber());
LinalgOp fusedProducer =
- fuse(b, producerOp, producerOp.getTiedIndexingMap(opOperand),
+ fuse(b, producerOp, producerOp.getMatchingIndexingMap(opOperand),
consumerOpOperand);
// Replace use.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index ff28663d479b7..d4e3b52f30e3f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -38,7 +38,7 @@ static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
ArrayRef<int64_t> tiledLoopDims) {
// Get the consumer operand indexing map.
LinalgOp consumerOp = consumerOperand->getOwner();
- AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand);
+ AffineMap indexingMap = consumerOp.getMatchingIndexingMap(consumerOperand);
// Search the slice dimensions tiled by a tile loop dimension.
DenseSet<int64_t> tiledSliceDimIndices;
@@ -68,7 +68,7 @@ getTiledProducerLoops(OpResult producerResult,
// Get the indexing map of the `producerOp` output operand that matches
// ´producerResult´.
- AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(
+ AffineMap producerIndexingMap = producerOp.getMatchingIndexingMap(
producerOp.getOutputOperand(producerResult.getResultNumber()));
// Keep only the tiled result slice dimensions of `producerIndexingMap`.
@@ -351,7 +351,7 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
// Check `consumerOpOperand` is not shape-only to avoid fusion if the data is
// not used by the `consumerOp` computation.
- BlockArgument bbArg = consumerOp.getTiedBlockArgument(consumerOpOperand);
+ BlockArgument bbArg = consumerOp.getMatchingBlockArgument(consumerOpOperand);
if (bbArg.getUses().empty())
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
index 54c92e927c4e7..04e94b1014e49 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -42,7 +42,7 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
SmallVector<AffineMap> newIndexingMaps;
SmallVector<Value> newOperands;
for (OpOperand *opOperand : genericOp.getInputOperands()) {
- AffineMap map = genericOp.getTiedIndexingMap(opOperand);
+ AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
if (genericOp.isInputTensor(opOperand) && map.isConstant()) {
scalarOperands.emplace_back(opOperand->getOperandNumber());
} else {
@@ -55,7 +55,7 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
return failure();
for (OpOperand *opOperand : genericOp.getOutputOperands())
- newIndexingMaps.emplace_back(genericOp.getTiedIndexingMap(opOperand));
+ newIndexingMaps.emplace_back(genericOp.getMatchingIndexingMap(opOperand));
Location loc = genericOp->getLoc();
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
@@ -71,7 +71,7 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
for (auto idx : llvm::reverse(scalarOperands)) {
OpOperand *opOperand = genericOp.getInputOperand(idx);
- AffineMap map = genericOp.getTiedIndexingMap(opOperand);
+ AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
SmallVector<int64_t> indices = map.getConstantResults();
SmallVector<Value> indicesValues;
for (auto idx : indices)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
index 5065bc5b4d6f2..8641e1106310e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
@@ -68,7 +68,7 @@ mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
// 2. Compute the interchanged indexing maps.
SmallVector<AffineMap> newIndexingMaps;
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
- AffineMap m = genericOp.getTiedIndexingMap(opOperand);
+ AffineMap m = genericOp.getMatchingIndexingMap(opOperand);
if (!permutationMap.isEmpty())
m = m.compose(permutationMap);
newIndexingMaps.push_back(m);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 6384b1fb92672..a14994f1bd077 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -144,14 +144,14 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
continue;
}
auto indexing = makeCanonicalAffineApplies(
- b, loc, linalgOp.getTiedIndexingMap(inputOperand), allIvsPlusDims);
+ b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims);
indexedValues.push_back(
b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
}
// 1.b. Emit load from output views.
for (OpOperand *outputOperand : linalgOp.getOutputOperands()) {
SmallVector<Value> indexing = makeCanonicalAffineApplies(
- b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims);
+ b, loc, linalgOp.getMatchingIndexingMap(outputOperand), allIvsPlusDims);
indexedValues.push_back(
b.create<LoadOpTy>(loc, outputOperand->get(), indexing));
}
@@ -163,7 +163,8 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
SmallVector<Value> outputBuffers;
for (OpOperand *outputOperand : linalgOp.getOutputBufferOperands()) {
indexing.push_back(makeCanonicalAffineApplies(
- b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims));
+ b, loc, linalgOp.getMatchingIndexingMap(outputOperand),
+ allIvsPlusDims));
outputBuffers.push_back(outputOperand->get());
}
inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 61eb384c8d750..56a7437d01f69 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -117,7 +117,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
SmallVector<AffineMap> newMaps;
// Calculate the new shapes and indexing maps of the input operands.
for (OpOperand *operand : op.getInputOperands()) {
- AffineMap map = op.getTiedIndexingMap(operand);
+ AffineMap map = op.getMatchingIndexingMap(operand);
SmallVector<int64_t> newShape;
SmallVector<AffineExpr> exprs;
SmallVector<ReassociationIndices> reassociation;
@@ -171,7 +171,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
// Calculate the new output map and shape, we insert the new dimension based
// on the index returned by `controlSplitReductionFn`.
SmallVector<int64_t> newOutputShape;
- AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0));
+ AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getOutputOperand(0));
ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0));
SmallVector<AffineExpr> outputExpr;
for (unsigned idx :
@@ -273,7 +273,7 @@ static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
int64_t reductionRatio) {
auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
- AffineMap map = op.getTiedIndexingMap(&opOperand);
+ AffineMap map = op.getMatchingIndexingMap(&opOperand);
AffineMap idMap =
AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
@@ -286,7 +286,7 @@ static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
unsigned reductionDimPos, int64_t size) {
auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
- AffineMap map = op.getTiedIndexingMap(&opOperand);
+ AffineMap map = op.getMatchingIndexingMap(&opOperand);
AffineMap idMap =
AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 177db809b4196..a9113f0d05713 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -62,7 +62,7 @@ static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
Value toStore = map.lookupOrDefault(operand.value());
OpOperand *storeInto = linalgOp.getOutputOperand(operand.index());
auto indices = getIndicesForAccess(
- b, loc, linalgOp.getTiedIndexingMap(storeInto), ivs);
+ b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs);
b.create<memref::StoreOp>(loc, toStore,
linalgOp.getOutputOperand(operand.index())->get(),
indices);
@@ -162,10 +162,10 @@ struct LinalgOpTilingInterface
}));
OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
- SliceParameters sliceParams =
- computeSliceParameters(b, loc, outOperand->get(), sizes,
- linalgOp.getTiedIndexingMap(outOperand), offsets,
- /*ubs*/ {}, subShapeSizes, true);
+ SliceParameters sliceParams = computeSliceParameters(
+ b, loc, outOperand->get(), sizes,
+ linalgOp.getMatchingIndexingMap(outOperand), offsets,
+ /*ubs*/ {}, subShapeSizes, true);
resultOffsets = sliceParams.offsets;
resultSizes = sliceParams.sizes;
return success();
@@ -182,7 +182,7 @@ struct LinalgOpTilingInterface
// map the offsets and sizes from the result to iteration space tiles
// (filling in full extent for dimensions not used to access the result).
AffineMap indexingMap =
- linalgOp.getTiedIndexingMapForResult(op->getResult(resultNumber));
+ linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber));
if (!indexingMap.isProjectedPermutation()) {
return op->emitOpError(
"unhandled tiled implementation generation when result is not "
@@ -238,7 +238,7 @@ struct LinalgOpTilingInterface
continue;
}
SmallVector<Value> indices = getIndicesForAccess(
- builder, linalgOpLoc, linalgOp.getTiedIndexingMap(operand), ivs);
+ builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(operand), ivs);
Value load =
builder.create<memref::LoadOp>(linalgOpLoc, operand->get(), indices);
indexedValues.push_back(load);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 0cbc975090c04..575dfbbc0909b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -172,7 +172,7 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
ArrayRef<int64_t> paddingDimensions, ArrayRef<Attribute> paddingValues,
ArrayRef<bool> packPaddings) {
- AffineMap indexingMap = opToPad.getTiedIndexingMap(opOperand);
+ AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand);
ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
// Collect the shape dimension that are a function of the `paddingDimensions`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index eaea7da592356..46dc2324eff2d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -215,7 +215,7 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
if (vectorType.getRank() > 0) {
// 0-d case is still special: do not invert the reindexing map.
AffineMap map =
- reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand));
+ reindexIndexingMap(linalgOp.getMatchingIndexingMap(outputOperand));
SmallVector<int64_t> transposeShape =
applyPermutationMap(inversePermutation(map), vectorType.getShape());
assert(!transposeShape.empty() && "unexpected empty transpose shape");
@@ -479,12 +479,12 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
// } else {
if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
map = inverseAndBroadcastProjectedPermutation(
- linalgOp.getTiedIndexingMap(opOperand));
+ linalgOp.getMatchingIndexingMap(opOperand));
readType = VectorType::get(commonVectorShape,
getElementTypeOrSelf(opOperand->get()));
} else {
map = inversePermutation(
- reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
+ reindexIndexingMap(linalgOp.getMatchingIndexingMap(opOperand)));
readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
getElementTypeOrSelf(opOperand->get()));
}
@@ -545,7 +545,7 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
return failure();
}
for (OpOperand *opOperand : op.getOutputOperands()) {
- AffineMap indexingMap = op.getTiedIndexingMap(opOperand);
+ AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
if (indexingMap.isPermutation())
continue;
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index c82cacdc3db44..d39fa11f364a2 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -180,7 +180,7 @@ bool isElementwise(LinalgOp op) {
// TODO: relax the restrictions on indexing map.
for (OpOperand *opOperand : op.getOutputOperands()) {
- if (!op.getTiedIndexingMap(opOperand).isPermutation())
+ if (!op.getMatchingIndexingMap(opOperand).isPermutation())
return false;
}
return hasOnlyScalarElementwiseOp(op->getRegion(0));
@@ -967,7 +967,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
Value shapedOp = valuesToTile[opOperand->getOperandNumber()];
LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
- AffineMap map = linalgOp.getTiedIndexingMap(opOperand);
+ AffineMap map = linalgOp.getMatchingIndexingMap(opOperand);
// Use `opOperand` as is if it is not tiled and not an output tensor. Having
// an extract/insert slice pair for all output tensors simplifies follow up
// transformations such as padding and bufferization since the
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index f4519346229d5..18dd53947bab2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -170,9 +170,9 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
if (!op.hasTensorSemantics() || op.getNumInputs() != 2 ||
op.getNumResults() != 1 ||
op.getNumParallelLoops() != op.getNumLoops() ||
- !op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
- !op.getTiedIndexingMap(op.getInputOperand(0)).isIdentity() ||
- !op.getTiedIndexingMap(op.getInputOperand(1)).isIdentity())
+ !op.getMatchingIndexingMap(op.getOutputOperand(0)).isIdentity() ||
+ !op.getMatchingIndexingMap(op.getInputOperand(0)).isIdentity() ||
+ !op.getMatchingIndexingMap(op.getInputOperand(1)).isIdentity())
return failure();
// Find consuming OP2(sparse, other) or OP2(other, sparse). The other
// operand can be sparse or dense, since the point of this rewriting rule
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 7019058720bad..ddcf839b1822e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -195,7 +195,7 @@ static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
bool annotated = false;
for (OpOperand *t : op.getInputAndOutputOperands()) {
- auto map = op.getTiedIndexingMap(t);
+ auto map = op.getMatchingIndexingMap(t);
auto enc = getSparseTensorEncoding(t->get().getType());
if (enc)
annotated = true;
@@ -296,7 +296,7 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
if (t == skip)
continue;
// Get map and encoding.
- auto map = op.getTiedIndexingMap(t);
+ auto map = op.getMatchingIndexingMap(t);
auto enc = getSparseTensorEncoding(t->get().getType());
assert(map.getNumDims() == n);
// Skip dense tensor constraints when not requested.
@@ -542,7 +542,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
for (OpOperand *t : op.getInputAndOutputOperands()) {
unsigned tensor = t->getOperandNumber();
auto shape = op.getShape(t);
- auto map = op.getTiedIndexingMap(t);
+ auto map = op.getMatchingIndexingMap(t);
auto enc = getSparseTensorEncoding(t->get().getType());
// Scan all dimensions of current tensor.
args.clear();
@@ -721,7 +721,7 @@ static Value genAffine(CodeGen &codegen, OpBuilder &builder, AffineExpr a,
/// Generates index for load/store on sparse tensor.
static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) {
- auto map = op.getTiedIndexingMap(t);
+ auto map = op.getMatchingIndexingMap(t);
auto enc = getSparseTensorEncoding(t->get().getType());
AffineExpr a = map.getResult(toOrigDim(enc, map.getNumResults() - 1));
assert(a.getKind() == AffineExprKind::DimId);
@@ -734,7 +734,7 @@ static Value genSubscript(CodeGen &codegen, OpBuilder &builder,
linalg::GenericOp op, OpOperand *t,
SmallVector<Value, 4> &args) {
unsigned tensor = t->getOperandNumber();
- auto map = op.getTiedIndexingMap(t);
+ auto map = op.getMatchingIndexingMap(t);
auto enc = getSparseTensorEncoding(t->get().getType());
unsigned rank = map.getNumResults();
if (enc) {
@@ -1079,7 +1079,7 @@ static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
// Inspect tensor indices.
bool atLevel = ldx == -1u;
OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
- auto map = op.getTiedIndexingMap(t);
+ auto map = op.getMatchingIndexingMap(t);
auto enc = getSparseTensorEncoding(t->get().getType());
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
AffineExpr a = map.getResult(toOrigDim(enc, d));
@@ -1275,7 +1275,7 @@ static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
unsigned idx) {
for (OpOperand *t : op.getInputAndOutputOperands()) {
if (!getSparseTensorEncoding(t->get().getType())) {
- auto map = op.getTiedIndexingMap(t);
+ auto map = op.getMatchingIndexingMap(t);
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
AffineExpr a = map.getResult(d);
// Report non-unit stride if innermost index appears at an outer
@@ -1920,7 +1920,8 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
auto srcTp = tval.getType().cast<RankedTensorType>();
auto dstEnc = SparseTensorEncodingAttr::get(
op->getContext(), srcEnc.getDimLevelType(),
- permute(getContext(), op.getTiedIndexingMap(t), topSort), // new order
+ permute(getContext(), op.getMatchingIndexingMap(t),
+ topSort), // new order
srcEnc.getPointerBitWidth(), srcEnc.getIndexBitWidth());
auto dstTp = RankedTensorType::get(srcTp.getShape(),
srcTp.getElementType(), dstEnc);
More information about the Mlir-commits
mailing list