[Mlir-commits] [mlir] [MLIR][Vector] Generalize DropUnitDimFromElementwiseOps to non leading / trailing dimensions. (PR #98455)

Hugo Trachino llvmlistbot at llvm.org
Fri Jul 12 02:18:07 PDT 2024


https://github.com/nujaa updated https://github.com/llvm/llvm-project/pull/98455

>From b2365584da57ac40ce99b2f380cd3e0e31166671 Mon Sep 17 00:00:00 2001
From: Hugo Trachino <hugo.trachino at huawei.com>
Date: Thu, 20 Jun 2024 10:43:23 +0100
Subject: [PATCH 1/3] [MLIR][Vector] Generalize DropUnitDimFromElementwiseOps
 to non leading / trailing dimensions. (#92934)

Generalizes `DropUnitDimFromElementwiseOps` to support inner unit
dimensions.
This change stems from improving the lowering of contractionOps for Arm SME,
where we end up with inner unit dimensions on MulOp, BroadcastOp and
TransposeOp, preventing the generation of outer products.
This was discussed
[here](https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa).

---------

Co-authored-by: Benjamin Maxwell <macdue at dueutil.tech>
---
 .../Vector/Transforms/VectorTransforms.cpp    | 55 ++++++++++---------
 .../Vector/vector-transfer-flatten.mlir       | 36 ++++++++++++
 2 files changed, 65 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index da5954b70a2ec..c7d3022eff4d3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1622,7 +1622,27 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
-/// For vectors with either leading or trailing unit dim, replaces:
+// Scalable unit dimensions are not supported. Folding such dimensions would
+// require "shifting" the scalable flag onto some other fixed-width dim (e.g.
+// vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
+// future.
+static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
+  auto inVecShape = inVecTy.getShape();
+  SmallVector<int64_t> newShape;
+  SmallVector<bool> newScalableDims;
+  for (auto [dim, isScalable] :
+       llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
+    if (dim == 1 && !isScalable)
+      continue;
+
+    newShape.push_back(dim);
+    newScalableDims.push_back(isScalable);
+  }
+
+  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
+}
+
+/// For vectors with at least one unit dim, replaces:
 ///   elementwise(a, b)
 /// with:
 ///   sc_a = shape_cast(a)
@@ -1634,20 +1654,16 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
 /// required to be rank > 1.
 ///
 /// Ex:
-/// ```
 ///  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
 ///  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
-/// ```
 ///
 /// gets converted to:
 ///
-/// ```
 ///  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
 ///  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
 ///  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
 ///  %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
 ///  %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
-/// ```
 ///
 /// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
 /// `%cast`.
@@ -1667,42 +1683,29 @@ struct DropUnitDimFromElementwiseOps final
     // guaranteed to have identical shapes (with some exceptions such as
     // `arith.select`) and it suffices to only check one of them.
     auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
-    if (!sourceVectorType)
-      return failure();
-    if (sourceVectorType.getRank() < 2)
-      return failure();
-
-    bool hasTrailingDimUnitFixed =
-        ((sourceVectorType.getShape().back() == 1) &&
-         (!sourceVectorType.getScalableDims().back()));
-    bool hasLeadingDimUnitFixed =
-        ((sourceVectorType.getShape().front() == 1) &&
-         (!sourceVectorType.getScalableDims().front()));
-    if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
+    if (!sourceVectorType || sourceVectorType.getRank() < 2)
       return failure();
 
-    // Drop leading/trailing unit dim by applying vector.shape_cast to all
-    // operands
-    int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
-
     SmallVector<Value> newOperands;
     auto loc = op->getLoc();
     for (auto operand : op->getOperands()) {
       auto opVectorType = cast<VectorType>(operand.getType());
-      VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
+      auto newVType = dropNonScalableUnitDimFromType(opVectorType);
+      if (newVType == opVectorType)
+        return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
+
       auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
       newOperands.push_back(opSC);
     }
 
     VectorType newResultVectorType =
-        VectorType::Builder(resultVectorType).dropDim(dim);
-    // Create an updated elementwise Op without leading/trailing unit dim
+        dropNonScalableUnitDimFromType(resultVectorType);
+    // Create an updated elementwise Op without unit dim.
     Operation *elementwiseOp =
         rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                         newResultVectorType, op->getAttrs());
 
-    // Restore the leading/trailing unit dim by applying vector.shape_cast
-    // to the result
+    // Restore the unit dim by applying vector.shape_cast to the result.
     rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
                                              elementwiseOp->getResult(0));
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 5fd3cbd54aa58..3a5041fca53fc 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -604,6 +604,42 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 
 // -----
 
+func.func @fold_inner_unit_dim(%arg0 : vector<8x1x3xf128>,
+                              %arg1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
+   %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x3xf128> to vector<8x1x3xf128>
+   %mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x3xf128>
+   %res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128>
+   return %res : vector<8x3xf128>
+}
+
+// CHECK-LABEL: func.func @fold_inner_unit_dim(
+// CHECK-SAME:    %[[VAL_0:.*]]: vector<8x1x3xf128>,
+// CHECK-SAME:    %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
+// CHECK:         %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
+// CHECK:         %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128>
+// CHECK:         %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128>
+// CHECK:         return %[[VAL_4]] : vector<8x3xf128>
+
+// -----
+
+func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
+                              %arg1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
+   %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128>
+   %mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]x3xf128>
+   %res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
+   return %res : vector<8x[1]x3xf128>
+}
+
+// CHECK-LABEL: func.func @fold_inner_unit_dim_scalable(
+// CHECK-SAME:    %[[VAL_0:.*]]: vector<8x1x[1]x3xf128>,
+// CHECK-SAME:    %[[VAL_1:.*]]: vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
+// CHECK:         %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
+// CHECK:         %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]x3xf128> to vector<8x[1]x3xf128>
+// CHECK:         %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]x3xf128>
+// CHECK:         return %[[VAL_4]] : vector<8x[1]x3xf128>
+
+// -----
+
 func.func @negative_out_of_bound_transfer_read(
     %arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
   %c0 = arith.constant 0 : index

>From 3909145bc2191af858bac24d24c007861556a5d0 Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Thu, 11 Jul 2024 22:46:52 +0800
Subject: [PATCH 2/3] FIXUP: Handle all unit dims case

---
 .../Vector/Transforms/VectorTransforms.cpp        |  8 +++++++-
 .../Dialect/Vector/vector-transfer-flatten.mlir   | 15 +++++++++++++++
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index c7d3022eff4d3..9b837efe7662c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1622,7 +1622,8 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
-// Scalable unit dimensions are not supported. Folding such dimensions would
+// Helper function dropping unit non-scalable dimension from a VectorType.
+// Scalable unit dimensions are not dropped. Folding such dimensions would
 // require "shifting" the scalable flag onto some other fixed-width dim (e.g.
 // vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
 // future.
@@ -1638,6 +1639,11 @@ static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
     newShape.push_back(dim);
     newScalableDims.push_back(isScalable);
   }
+  // All dims have been dropped, we need to return a legal shape for VectorType.
+  if (newShape.empty()) {
+    newShape.push_back(1);
+    newScalableDims.push_back(false);
+  }
 
   return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
 }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 3a5041fca53fc..303f841e8a828 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -640,6 +640,21 @@ func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
 
 // -----
 
+func.func @fold_all_unit_dims(%arg0: vector<1x1xf32>) -> vector<1xf32> {
+  %0 = arith.mulf %arg0, %arg0 : vector<1x1xf32>
+  %res = vector.shape_cast %0 : vector<1x1xf32> to vector<1xf32>
+  return %res : vector<1xf32>
+}
+
+// CHECK-LABEL: func.func @fold_all_unit_dims(
+// CHECK-SAME:    %[[VAL_0:.*]]: vector<1x1xf32>) -> vector<1xf32>
+// CHECK:         %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
+// CHECK:         %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
+// CHECK:         %[[VAL_3:.*]] = arith.mulf %[[VAL_1]], %[[VAL_2]] : vector<1xf32>
+// CHECK:         return %[[VAL_3]] : vector<1xf32>
+
+// -----
+
 func.func @negative_out_of_bound_transfer_read(
     %arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
   %c0 = arith.constant 0 : index

>From 3b99b462043e84e5a7bfe69240aa8bc1467b0647 Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Fri, 12 Jul 2024 17:17:53 +0800
Subject: [PATCH 3/3] Document 1D edgecase

---
 .../Dialect/Vector/Transforms/VectorTransforms.cpp   | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 9b837efe7662c..b867254113971 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1622,11 +1622,11 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
-// Helper function dropping unit non-scalable dimension from a VectorType.
-// Scalable unit dimensions are not dropped. Folding such dimensions would
-// require "shifting" the scalable flag onto some other fixed-width dim (e.g.
-// vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
-// future.
+// Helper function dropping non-scalable unit dimensions from a VectorType,
+// keeping at least 1 dimension. Scalable unit dimensions are not dropped.
+// Folding such dimensions would require "shifting" the scalable flag onto some
+// other fixed-width dim (e.g. vector<[1]x4xf32> -> vector<[4]xf32>). This could
+// be implemented in the future.
 static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
   auto inVecShape = inVecTy.getShape();
   SmallVector<int64_t> newShape;
@@ -1639,7 +1639,7 @@ static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
     newShape.push_back(dim);
     newScalableDims.push_back(isScalable);
   }
-  // All dims have been dropped, we need to return a legal shape for VectorType.
+  // All dims have been dropped, return vector<1xeType>.
   if (newShape.empty()) {
     newShape.push_back(1);
     newScalableDims.push_back(false);



More information about the Mlir-commits mailing list