[Mlir-commits] [mlir] 9c27fa3 - [mlir][linalg] Prepare fusion on tensors for scalar operands.
Tobias Gysi
llvmlistbot at llvm.org
Wed Jun 9 00:11:13 PDT 2021
Author: Tobias Gysi
Date: 2021-06-09T07:09:46Z
New Revision: 9c27fa3821dc5c04f5710e64411815893de160ce
URL: https://github.com/llvm/llvm-project/commit/9c27fa3821dc5c04f5710e64411815893de160ce
DIFF: https://github.com/llvm/llvm-project/commit/9c27fa3821dc5c04f5710e64411815893de160ce.diff
LOG: [mlir][linalg] Prepare fusion on tensors for scalar operands.
Adapt fusion on tensors to support structured ops taking scalar operands.
Differential Revision: https://reviews.llvm.org/D103889
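For context, a minimal, hypothetical example of a structured op taking a scalar operand (not taken from the commit; the function name, value names, shapes, and body are illustrative, written in June-2021 MLIR syntax): a linalg.generic whose first input is a plain f32 bound to a zero-result indexing map. Such an operand has no RankedTensorType to expand, which is why the fusion patterns in the diff below now guard on isInputTensor() / getInputTensorOperands() and forward scalar operands unchanged.

  // Illustrative sketch only (assumed names %scale, %in, %init).
  #scalar = affine_map<(d0, d1) -> ()>
  #id = affine_map<(d0, d1) -> (d0, d1)>

  func @scale_elementwise(%scale: f32, %in: tensor<2x4xf32>,
                          %init: tensor<2x4xf32>) -> tensor<2x4xf32> {
    // The scalar %scale is read as-is in every iteration; only %in and %init
    // carry shapes that reshape-based fusion may need to expand.
    %0 = linalg.generic {indexing_maps = [#scalar, #id, #id],
                         iterator_types = ["parallel", "parallel"]}
        ins(%scale, %in : f32, tensor<2x4xf32>)
        outs(%init : tensor<2x4xf32>) {
    ^bb0(%s: f32, %x: f32, %out: f32):
      %1 = mulf %s, %x : f32
      linalg.yield %1 : f32
    } -> tensor<2x4xf32>
    return %0 : tensor<2x4xf32>
  }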
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index d4dbb5aeb7c27..f65a0fa1772d4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -701,24 +701,27 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
       }));
 
   SmallVector<Value> expandedOpOperands;
+  expandedOpOperands.reserve(genericOp.getNumInputs());
   for (OpOperand *opOperand : genericOp.getInputOperands()) {
     if (opOperand == fusableOpOperand) {
       expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
                                                : collapsingReshapeOp.src());
       continue;
     }
-    AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
-    RankedTensorType expandedOperandType =
-        getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
-                        indexingMap, expansionInfo);
-    if (expandedOperandType != opOperand->get().getType()) {
-      // Reshape the operand to get the right type.
-      SmallVector<ReassociationIndices> reassociation =
-          getReassociationForExpansion(indexingMap, expansionInfo);
-      expandedOpOperands.push_back(rewriter.create<TensorExpandShapeOp>(
-          genericOp.getLoc(), expandedOperandType, opOperand->get(),
-          reassociation));
-      continue;
+    if (genericOp.isInputTensor(opOperand)) {
+      AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+      RankedTensorType expandedOperandType =
+          getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
+                          indexingMap, expansionInfo);
+      if (expandedOperandType != opOperand->get().getType()) {
+        // Reshape the operand to get the right type.
+        SmallVector<ReassociationIndices> reassociation =
+            getReassociationForExpansion(indexingMap, expansionInfo);
+        expandedOpOperands.push_back(rewriter.create<TensorExpandShapeOp>(
+            genericOp.getLoc(), expandedOperandType, opOperand->get(),
+            reassociation));
+        continue;
+      }
     }
     expandedOpOperands.push_back(opOperand->get());
   }
@@ -1035,7 +1038,7 @@ class FoldWithProducerReshapeOpByExpansion
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    for (OpOperand *opOperand : genericOp.getInputOperands()) {
+    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
       TensorCollapseShapeOp reshapeOp =
           opOperand->get().getDefiningOp<TensorCollapseShapeOp>();
       if (!reshapeOp)