[Mlir-commits] [mlir] c698505 - [mlir][linalg] Cleanup LinalgOp usage in drop unit dims.
Tobias Gysi
llvmlistbot at llvm.org
Thu Jun 3 05:52:35 PDT 2021
Author: Tobias Gysi
Date: 2021-06-03T12:27:05Z
New Revision: c698505257598d04f8e92a7ee79bfdf7c2cc6020
URL: https://github.com/llvm/llvm-project/commit/c698505257598d04f8e92a7ee79bfdf7c2cc6020
DIFF: https://github.com/llvm/llvm-project/commit/c698505257598d04f8e92a7ee79bfdf7c2cc6020.diff
LOG: [mlir][linalg] Cleanup LinalgOp usage in drop unit dims.
Replace the uses of deprecated Structured Op Interface methods in DropUnitDims.cpp. This patch is based on https://reviews.llvm.org/D103394.
Differential Revision: https://reviews.llvm.org/D103448
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 5e8820535a41..fb5990786802 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -183,9 +183,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
if (!invertedMap)
return failure();
- SmallVector<int64_t, 4> dims;
- for (ShapedType shapedType : genericOp.getShapedOperandTypes())
- dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
+ SmallVector<int64_t> dims = genericOp.getStaticShape();
// Find all the reduction iterators. Those need some special consideration
// (see below).
@@ -267,17 +265,18 @@ struct UnitExtentReplacementInfo {
/// - modified index map that can be used to access the replaced result/operand
/// - the reassociation that converts from the original tensor type to the
/// modified tensor type.
-static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
- RankedTensorType type,
+static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
+ OpOperand *opOperand,
MLIRContext *context) {
- ArrayRef<int64_t> shape = type.getShape();
- ArrayRef<AffineExpr> exprs = indexMap.getResults();
+ AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+ ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
+ ArrayRef<AffineExpr> exprs = indexingMap.getResults();
SmallVector<AffineExpr, 2> reassociations;
SmallVector<Attribute, 4> reassociationMaps;
SmallVector<AffineExpr, 4> newIndexExprs;
SmallVector<int64_t, 4> newShape;
- int64_t origRank = type.getRank();
+ int64_t origRank = genericOp.getRank(opOperand);
AffineExpr zeroExpr = getAffineConstantExpr(0, context);
auto isUnitExtent = [&](int64_t dim) -> bool {
return shape[dim] == 1 && exprs[dim] == zeroExpr;
@@ -302,8 +301,9 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
++dim;
}
UnitExtentReplacementInfo info = {
- RankedTensorType::get(newShape, type.getElementType()),
- AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
+ RankedTensorType::get(newShape,
+ getElementTypeOrSelf(opOperand->get().getType())),
+ AffineMap::get(indexingMap.getNumDims(), indexingMap.getNumSymbols(),
newIndexExprs, context),
ArrayAttr::get(context, reassociationMaps)};
return info;
@@ -335,15 +335,13 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
SmallVector<ArrayAttr, 4> reassociationMaps;
SmallVector<ShapedType, 4> newInputOutputTypes;
bool doCanonicalization = false;
- for (auto it : llvm::zip(genericOp.getIndexingMaps(),
- genericOp.getShapedOperandTypes())) {
- auto replacementInfo = replaceUnitExtents(
- std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
- context);
+
+ for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
+ auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
reassociationMaps.push_back(replacementInfo.reassociation);
newIndexingMaps.push_back(replacementInfo.indexMap);
newInputOutputTypes.push_back(replacementInfo.type);
- doCanonicalization |= replacementInfo.type != std::get<1>(it);
+ doCanonicalization |= replacementInfo.type != opOperand->get().getType();
}
// If the indexing maps of the result operation are not invertible (i.e. not
More information about the Mlir-commits mailing list