[Mlir-commits] [mlir] [mlir][vector] Add pattern to drop unit dim from elementwise(a, b) (PR #74817)

Diego Caballero llvmlistbot at llvm.org
Wed Dec 13 01:39:34 PST 2023


================
@@ -1446,6 +1446,90 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
+/// Replace:
+///   elementwise(a, b)
+/// with:
+///   sc_a = shape_cast(a)
+///   sc_b = shape_cast(b)
+///   res = elementwise(sc_a, sc_b)
+///   return shape_cast(res)
+/// for which `a` and `b` are vectors of rank > 1 and have a unit leading and/or
+/// trailing dimension.
+///
+/// Ex:
+/// ```
+///  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
+///  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
+///
+/// gets converted to:
+///
+/// ```
+///  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
+///  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
+///  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
+///  %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
+///  %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
+///
+/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
+/// `%cast`.
+struct DropUnitDimFromElementwiseOps final
+    : public OpTraitRewritePattern<OpTrait::Elementwise> {
+  using OpTraitRewritePattern::OpTraitRewritePattern;
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumResults() != 1)
+      return failure();
+
+    // Check the pre-conditions. For `Elementwise` Ops all operands
+    // are guaranteed to have identical shapes and it suffices to only check the
+    // first one.
+    auto op1 = op->getOperands()[0];
+    auto sourceVectorType = dyn_cast<VectorType>(op1.getType());
+    if (!sourceVectorType)
+      return failure();
+
+    if (sourceVectorType.getRank() < 2)
+      return failure();
+
+    bool trailingDimUnitFixed = ((sourceVectorType.getShape().back() == 1) &&
+                                 (!sourceVectorType.getScalableDims().back()));
+    bool leadDimUnitFixed = ((sourceVectorType.getShape().front() == 1) &&
----------------
dcaballe wrote:

nit: -> ~`hasTrailingUnitDim`, `hasLeadingUnitDim`?

https://github.com/llvm/llvm-project/pull/74817


More information about the Mlir-commits mailing list