[Mlir-commits] [mlir] f6b4e08 - [mlir][linalg] Prepare drop unit dims for scalar operands.
Tobias Gysi
llvmlistbot at llvm.org
Fri Jun 11 06:19:42 PDT 2021
Author: Tobias Gysi
Date: 2021-06-11T13:18:06Z
New Revision: f6b4e081dc9cf74fb5c22439f552fa035f2c2651
URL: https://github.com/llvm/llvm-project/commit/f6b4e081dc9cf74fb5c22439f552fa035f2c2651
DIFF: https://github.com/llvm/llvm-project/commit/f6b4e081dc9cf74fb5c22439f552fa035f2c2651.diff
LOG: [mlir][linalg] Prepare drop unit dims for scalar operands.
Adapt the drop-unit-dims transformation so that it handles structured ops that take scalar operands.
Differential Revision: https://reviews.llvm.org/D103890
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index fb5990786802a..102dbdb4e2c36 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -249,7 +249,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
};
struct UnitExtentReplacementInfo {
- RankedTensorType type;
+ Type type;
AffineMap indexMap;
ArrayAttr reassociation;
};
@@ -271,10 +271,10 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
ArrayRef<AffineExpr> exprs = indexingMap.getResults();
- SmallVector<AffineExpr, 2> reassociations;
- SmallVector<Attribute, 4> reassociationMaps;
- SmallVector<AffineExpr, 4> newIndexExprs;
- SmallVector<int64_t, 4> newShape;
+ SmallVector<AffineExpr> reassociations;
+ SmallVector<Attribute> reassociationMaps;
+ SmallVector<AffineExpr> newIndexExprs;
+ SmallVector<int64_t> newShape;
int64_t origRank = genericOp.getRank(opOperand);
AffineExpr zeroExpr = getAffineConstantExpr(0, context);
@@ -282,7 +282,7 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
return shape[dim] == 1 && exprs[dim] == zeroExpr;
};
- unsigned dim = 0;
+ int64_t dim = 0;
// Fold dimensions that are unit-extent at the beginning of the tensor.
while (dim < origRank && isUnitExtent(dim))
reassociations.push_back(getAffineDimExpr(dim++, context));
@@ -300,12 +300,16 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
reassociations.clear();
++dim;
}
- UnitExtentReplacementInfo info = {
- RankedTensorType::get(newShape,
- getElementTypeOrSelf(opOperand->get().getType())),
- AffineMap::get(indexingMap.getNumDims(), indexingMap.getNumSymbols(),
- newIndexExprs, context),
- ArrayAttr::get(context, reassociationMaps)};
+ // Compute the tensor or scalar replacement type.
+ Type elementType = getElementTypeOrSelf(opOperand->get().getType());
+ Type replacementType = elementType == opOperand->get().getType()
+ ? elementType
+ : RankedTensorType::get(newShape, elementType);
+ UnitExtentReplacementInfo info = {replacementType,
+ AffineMap::get(indexingMap.getNumDims(),
+ indexingMap.getNumSymbols(),
+ newIndexExprs, context),
+ ArrayAttr::get(context, reassociationMaps)};
return info;
}
@@ -331,13 +335,14 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
MLIRContext *context = rewriter.getContext();
Location loc = genericOp.getLoc();
- SmallVector<AffineMap, 4> newIndexingMaps;
- SmallVector<ArrayAttr, 4> reassociationMaps;
- SmallVector<ShapedType, 4> newInputOutputTypes;
+ SmallVector<AffineMap> newIndexingMaps;
+ SmallVector<ArrayAttr> reassociationMaps;
+ SmallVector<Type> newInputOutputTypes;
bool doCanonicalization = false;
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
- auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
+ UnitExtentReplacementInfo replacementInfo =
+ replaceUnitExtents(genericOp, opOperand, context);
reassociationMaps.push_back(replacementInfo.reassociation);
newIndexingMaps.push_back(replacementInfo.indexMap);
newInputOutputTypes.push_back(replacementInfo.type);
More information about the Mlir-commits mailing list