[Mlir-commits] [mlir] [MLIR][Vector]Generalize DropUnitDimFromElementwiseOps (PR #92934)
Hugo Trachino
llvmlistbot at llvm.org
Fri May 31 04:08:00 PDT 2024
================
@@ -1652,42 +1668,30 @@ struct DropUnitDimFromElementwiseOps final
// guaranteed to have identical shapes (with some exceptions such as
// `arith.select`) and it suffices to only check one of them.
auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
- if (!sourceVectorType)
- return failure();
- if (sourceVectorType.getRank() < 2)
- return failure();
-
- bool hasTrailingDimUnitFixed =
- ((sourceVectorType.getShape().back() == 1) &&
- (!sourceVectorType.getScalableDims().back()));
- bool hasLeadingDimUnitFixed =
- ((sourceVectorType.getShape().front() == 1) &&
- (!sourceVectorType.getScalableDims().front()));
- if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
+ if (!sourceVectorType || sourceVectorType.getRank() < 2)
return failure();
- // Drop leading/trailing unit dim by applying vector.shape_cast to all
- // operands
- int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
-
SmallVector<Value> newOperands;
auto loc = op->getLoc();
for (auto operand : op->getOperands()) {
auto opVectorType = cast<VectorType>(operand.getType());
- VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
- auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
+ auto newVType = dropNonScalableUnitDimType(opVectorType);
+ if (failed(newVType)) {
+ return failure();
+ }
----------------
nujaa wrote:
👍
https://github.com/llvm/llvm-project/pull/92934
More information about the Mlir-commits
mailing list