[Mlir-commits] [mlir] [mlir][vector] Add pattern to drop unit dim from elementwise(a, b)) (PR #74817)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Dec 13 03:28:54 PST 2023
================
@@ -1446,6 +1446,92 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
}
};
+/// Replace:
+/// elementwise(a, b)
+/// with:
+/// sc_a = shape_cast(a)
+/// sc_b = shape_cast(b)
+/// res = elementwise(sc_a, sc_b)
+/// return shape_cast(res)
+/// for which `a` and `b` are vectors of rank >= 2 and have a unit, fixed-size
+/// (i.e. non-scalable) leading and/or trailing dimension.
+///
+/// Ex:
+/// ```
+/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
+/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
+///
+/// gets converted to:
+///
+/// ```
+/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
+/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
+/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
+/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
+/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
+///
+/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
+/// `%cast`.
+struct DropUnitDimFromElementwiseOps final
+ : public OpTraitRewritePattern<OpTrait::Elementwise> {
+ using OpTraitRewritePattern::OpTraitRewritePattern;
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ if (op->getNumResults() != 1)
+ return failure();
+
+ // Check the pre-condiitions. For `Elementwise` Ops all operands
+ // are guaranteed to have identical shapes and it suffices to only check the
+ // first one.
+ auto op1 = op->getOperands()[0];
+ auto sourceVectorType = dyn_cast<VectorType>(op1.getType());
+ if (!sourceVectorType)
+ return failure();
+
+ if (sourceVectorType.getRank() < 2)
+ return failure();
+
+ bool hasTrailingDimUnitFixed =
+ ((sourceVectorType.getShape().back() == 1) &&
+ (!sourceVectorType.getScalableDims().back()));
+ bool hasLeadingDimUnitFixed =
+ ((sourceVectorType.getShape().front() == 1) &&
+ (!sourceVectorType.getScalableDims().front()));
+ if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
+ return failure();
+
+ // Drop leading/trailing unit dim by applying vector.shape_cast to all
+ // operands
+ auto elTy = sourceVectorType.getElementType();
+ VectorType newVType =
+ hasLeadingDimUnitFixed
+ ? VectorType::get(sourceVectorType.getShape().drop_front(1), elTy,
+ sourceVectorType.getScalableDims().drop_front(1))
+ : VectorType::get(sourceVectorType.getShape().drop_back(1), elTy,
+ sourceVectorType.getScalableDims().drop_back(1));
----------------
banach-space wrote:
Nice, thanks!
https://github.com/llvm/llvm-project/pull/74817
More information about the Mlir-commits
mailing list