[Mlir-commits] [mlir] [mlir][vector] Better transfer_read(transfer_write) canonicalization (PR #72617)

Matthias Springer llvmlistbot at llvm.org
Thu Nov 16 23:38:01 PST 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/72617

This change improves the canonicalization of `transfer_read(transfer_write)` IR patterns where the two transfer ops access the same chunk of the shaped value (store-to-load forwarding). The existing rewrite pattern did not support cases where the two transfer ops operate on vectors of different rank (i.e., where the rank-reduced/extended unit dims differ between the two ops).

The previous pattern generated a combination of `vector.transpose` and `vector.broadcast`. The new pattern generates a combination of `vector.broadcast`, `vector.transpose` and `vector.extract`. In cases where no `vector.extract` is needed, other canonicalization patterns/foldings simplify the IR so that the same IR as with the previous pattern is produced.
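To make this concrete, here is the running example from the rewritten pattern documentation (third commit below), shown as a sketch: the broadcast adds the read's broadcast dims plus unit dims for shaped-value dims the write does not transfer, the transpose moves the dims that are not read to the front, and the extract drops them. The `%pad` padding operand is added here only so the read syntax is complete; `in_bounds` attributes are omitted as in the documentation comment.

```mlir
// Before: store-load forwarding candidate with different ranks/maps.
%0 = vector.transfer_write %vec, %s[%a, %b, %c, %d, %e, %f]
    {permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>}
    : vector<5x6x7x8xf32>, tensor<?x?x?x?x?x?xf32>
%1 = vector.transfer_read %0[%a, %b, %c, %d, %e, %f], %pad
    {permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, 0, d4, 0, d5, d0)>}
    : tensor<?x?x?x?x?x?xf32>, vector<6x5x100x7x200x8x1xf32>

// After: the transfer_read is replaced by broadcast + transpose + extract.
%bc = vector.broadcast %vec
    : vector<5x6x7x8xf32> to vector<100x200x1x1x5x6x7x8xf32>
%tp = vector.transpose %bc, [3, 5, 4, 0, 6, 1, 7, 2]
    : vector<100x200x1x1x5x6x7x8xf32> to vector<1x6x5x100x7x200x8x1xf32>
%1 = vector.extract %tp[0]
    : vector<6x5x100x7x200x8x1xf32> from vector<1x6x5x100x7x200x8x1xf32>
```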

Depends on #72594 and #72616. Review only the top commit.
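The `extract(transpose(broadcast(x)))` canonicalization added in the second commit is one of the supporting patterns mentioned above; the new test in `canonicalize.mlir` illustrates it. The extracted leading dim was introduced by the broadcast, so it can be dropped from both ops:

```mlir
// Before:
%0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<100x5x1x4x1xf32>
%1 = vector.transpose %0, [2, 4, 0, 3, 1]
    : vector<100x5x1x4x1xf32> to vector<1x1x100x4x5xf32>
%2 = vector.extract %1[0] : vector<1x100x4x5xf32> from vector<1x1x100x4x5xf32>

// After:
%0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<100x5x4x1xf32>
%1 = vector.transpose %0, [3, 0, 2, 1]
    : vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
```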


>From a4b57fdb77f637f82bb556d74f146138bbf1fab8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 17 Nov 2023 11:10:50 +0900
Subject: [PATCH 1/3] [mlir][vector] Modernize `vector.transpose` op

* Declare arguments/results with `let` statements.
* Rename `transp` to `permutation`.
* Change type of `transp` from `I64ArrayAttr` to `DenseI64ArrayAttr`.
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 17 +++--
 mlir/include/mlir/IR/AffineMap.h              |  2 +
 .../VectorToArmSME/VectorToArmSME.cpp         |  7 +-
 .../Conversion/VectorToGPU/VectorToGPU.cpp    |  7 +-
 .../Dialect/Arith/Transforms/IntNarrowing.cpp |  2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 64 +++++++------------
 .../Transforms/LowerVectorTranspose.cpp       |  4 +-
 .../Vector/Transforms/VectorTransforms.cpp    | 20 +++---
 .../Vector/Transforms/VectorUnroll.cpp        |  3 +-
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp |  7 +-
 mlir/lib/IR/AffineMap.cpp                     |  6 ++
 11 files changed, 59 insertions(+), 80 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 02e80a6446dfb24..1397d4caf1d9d61 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2436,14 +2436,13 @@ def Vector_TransposeOp :
   Vector_Op<"transpose", [Pure,
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
     PredOpTrait<"operand and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
-    Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$transp)>,
-    Results<(outs AnyVectorOfAnyRank:$result)> {
+                 TCresVTEtIsSameAsOpBase<0, 0>>]> {
   let summary = "vector transpose operation";
   let description = [{
     Takes a n-D vector and returns the transposed n-D vector defined by
     the permutation of ranks in the n-sized integer array attribute (in case
     of 0-D vectors the array attribute must be empty).
+
     In the operation
 
     ```mlir
@@ -2452,7 +2451,7 @@ def Vector_TransposeOp :
       to vector<d_trans[0] x .. x d_trans[n-1] x f32>
     ```
 
-    the transp array [i_1, .., i_n] must be a permutation of [0, .., n-1].
+    the `permutation` array [i_1, .., i_n] must be a permutation of [0, .., n-1].
 
     Example:
 
@@ -2464,8 +2463,13 @@ def Vector_TransposeOp :
                           [c, f] ]
     ```
   }];
+
+  let arguments = (ins AnyVectorOfAnyRank:$vector,
+                       DenseI64ArrayAttr:$permutation);
+  let results = (outs AnyVectorOfAnyRank:$result);
+
   let builders = [
-    OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$transp)>
+    OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$permutation)>
   ];
   let extraClassDeclaration = [{
     VectorType getSourceVectorType() {
@@ -2474,10 +2478,9 @@ def Vector_TransposeOp :
     VectorType getResultVectorType() {
       return ::llvm::cast<VectorType>(getResult().getType());
     }
-    void getTransp(SmallVectorImpl<int64_t> &results);
   }];
   let assemblyFormat = [{
-    $vector `,` $transp attr-dict `:` type($vector) `to` type($result)
+    $vector `,` $permutation attr-dict `:` type($vector) `to` type($result)
   }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 713aef767edf669..981f3d392cbc98c 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -103,6 +103,8 @@ class AffineMap {
   /// (i.e. `[1,1,2]` is an invalid permutation).
   static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
                                      MLIRContext *context);
+  static AffineMap getPermutationMap(ArrayRef<int64_t> permutation,
+                                     MLIRContext *context);
 
   /// Returns an affine map with `numDims` input dimensions and results
   /// specified by `targets`.
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 953a465c18de69f..01c782676068d9a 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -436,12 +436,9 @@ struct TransposeOpToArmSMELowering
     if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
       return failure();
 
-    SmallVector<int64_t> transp;
-    for (auto attr : transposeOp.getTransp())
-      transp.push_back(cast<IntegerAttr>(attr).getInt());
-
     // Bail unless this is a true 2-D matrix transpose.
-    if (transp[0] != 1 || transp[1] != 0)
+    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+    if (permutation[0] != 1 || permutation[1] != 0)
       return failure();
 
     OpBuilder::InsertionGuard g(rewriter);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 1126c2c20758c7a..429d1137b6f3781 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -473,13 +473,8 @@ struct CombineTransferReadOpTranspose final
     if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
       return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
 
-    SmallVector<int64_t, 2> perm;
-    op.getTransp(perm);
-    SmallVector<unsigned, 2> permU;
-    for (int64_t o : perm)
-      permU.push_back(unsigned(o));
     AffineMap permutationMap =
-        AffineMap::getPermutationMap(permU, op.getContext());
+        AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
     AffineMap newMap =
         permutationMap.compose(transferReadOp.getPermutationMap());
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 1084fbc890053b9..79fabd6ed2e99a2 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -712,7 +712,7 @@ struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
     VectorType newTy =
         origTy.cloneWith(origTy.getShape(), ext->getInElementType());
     Value newTranspose = rewriter.create<vector::TransposeOp>(
-        op.getLoc(), newTy, ext->getIn(), op.getTransp());
+        op.getLoc(), newTy, ext->getIn(), op.getPermutation());
     ext->recreateAndReplace(rewriter, op, newTranspose);
     return success();
   }
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 95f49fa32bc0ae2..c7b74701fdbc8f2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1456,9 +1456,8 @@ LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
 
   if (!nextTransposeOp)
     return failure();
-  auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp());
-  AffineMap m = inversePermutation(
-      AffineMap::getPermutationMap(permutation, extractOp.getContext()));
+  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
+      nextTransposeOp.getPermutation(), extractOp.getContext()));
   extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
   return success();
 }
@@ -5376,20 +5375,20 @@ LogicalResult TypeCastOp::verify() {
 //===----------------------------------------------------------------------===//
 
 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
-                                Value vector, ArrayRef<int64_t> transp) {
+                                Value vector, ArrayRef<int64_t> permutation) {
   VectorType vt = llvm::cast<VectorType>(vector.getType());
   SmallVector<int64_t, 4> transposedShape(vt.getRank());
   SmallVector<bool, 4> transposedScalableDims(vt.getRank());
-  for (unsigned i = 0; i < transp.size(); ++i) {
-    transposedShape[i] = vt.getShape()[transp[i]];
-    transposedScalableDims[i] = vt.getScalableDims()[transp[i]];
+  for (unsigned i = 0; i < permutation.size(); ++i) {
+    transposedShape[i] = vt.getShape()[permutation[i]];
+    transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
   }
 
   result.addOperands(vector);
   result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
                                   transposedScalableDims));
-  result.addAttribute(TransposeOp::getTranspAttrName(result.name),
-                      builder.getI64ArrayAttr(transp));
+  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
+                      builder.getDenseI64ArrayAttr(permutation));
 }
 
 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
@@ -5401,13 +5400,12 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
 
   // Eliminate identity transpose ops. This happens when the dimensions of the
   // input vector remain in their original order after the transpose operation.
-  SmallVector<int64_t, 4> transp;
-  getTransp(transp);
+  ArrayRef<int64_t> perm = getPermutation();
 
   // Check if the permutation of the dimensions contains sequential values:
   // {0, 1, 2, ...}.
-  for (int64_t i = 0, e = transp.size(); i < e; i++) {
-    if (transp[i] != i)
+  for (int64_t i = 0, e = perm.size(); i < e; i++) {
+    if (perm[i] != i)
       return {};
   }
 
@@ -5421,20 +5419,19 @@ LogicalResult vector::TransposeOp::verify() {
   if (vectorType.getRank() != rank)
     return emitOpError("vector result rank mismatch: ") << rank;
   // Verify transposition array.
-  auto transpAttr = getTransp().getValue();
-  int64_t size = transpAttr.size();
+  ArrayRef<int64_t> perm = getPermutation();
+  int64_t size = perm.size();
   if (rank != size)
     return emitOpError("transposition length mismatch: ") << size;
   SmallVector<bool, 8> seen(rank, false);
-  for (const auto &ta : llvm::enumerate(transpAttr)) {
-    int64_t i = llvm::cast<IntegerAttr>(ta.value()).getInt();
-    if (i < 0 || i >= rank)
-      return emitOpError("transposition index out of range: ") << i;
-    if (seen[i])
-      return emitOpError("duplicate position index: ") << i;
-    seen[i] = true;
-    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
-      return emitOpError("dimension size mismatch at: ") << i;
+  for (const auto &ta : llvm::enumerate(perm)) {
+    if (ta.value() < 0 || ta.value() >= rank)
+      return emitOpError("transposition index out of range: ") << ta.value();
+    if (seen[ta.value()])
+      return emitOpError("duplicate position index: ") << ta.value();
+    seen[ta.value()] = true;
+    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
+      return emitOpError("dimension size mismatch at: ") << ta.value();
   }
   return success();
 }
@@ -5452,13 +5449,6 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
 
   LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                 PatternRewriter &rewriter) const override {
-    // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
-    auto getPermutation = [](vector::TransposeOp transpose) {
-      SmallVector<int64_t, 4> permutation;
-      transpose.getTransp(permutation);
-      return permutation;
-    };
-
     // Composes two permutations: result[i] = permutation1[permutation2[i]].
     auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                   ArrayRef<int64_t> permutation2) {
@@ -5475,12 +5465,11 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
       return failure();
 
     SmallVector<int64_t, 4> permutation = composePermutations(
-        getPermutation(parentTransposeOp), getPermutation(transposeOp));
+        parentTransposeOp.getPermutation(), transposeOp.getPermutation());
     // Replace 'transposeOp' with a new transpose operation.
     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
         transposeOp, transposeOp.getResult().getType(),
-        parentTransposeOp.getVector(),
-        vector::getVectorSubscriptAttr(rewriter, permutation));
+        parentTransposeOp.getVector(), permutation);
     return success();
   }
 };
@@ -5539,8 +5528,7 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
 
     // Get the transpose permutation and apply it to the vector.create_mask or
     // vector.constant_mask operands.
-    SmallVector<int64_t> permutation;
-    transpOp.getTransp(permutation);
+    ArrayRef<int64_t> permutation = transpOp.getPermutation();
 
     if (createMaskOp) {
       auto maskOperands = createMaskOp.getOperands();
@@ -5572,10 +5560,6 @@ void vector::TransposeOp::getCanonicalizationPatterns(
               TransposeFolder, FoldTransposeSplat>(context);
 }
 
-void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
-  populateFromInt64AttrArray(getTransp(), results);
-}
-
 //===----------------------------------------------------------------------===//
 // ConstantMaskOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index dee786007c80630..97f6caca1b25ccc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -327,9 +327,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     VectorType resType = op.getResultVectorType();
 
     // Set up convenience transposition table.
-    SmallVector<int64_t> transp;
-    for (auto attr : op.getTransp())
-      transp.push_back(cast<IntegerAttr>(attr).getInt());
+    ArrayRef<int64_t> transp = op.getPermutation();
 
     if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
         succeeded(isTranspose2DSlice(op)))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 713f9cb72c82cec..a20c8aeeb6f7108 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -212,8 +212,7 @@ struct CombineContractABTranspose final
       if (!transposeOp)
         continue;
       AffineMap permutationMap = AffineMap::getPermutationMap(
-          extractVector<unsigned>(transposeOp.getTransp()),
-          contractOp.getContext());
+          transposeOp.getPermutation(), contractOp.getContext());
       map = inversePermutation(permutationMap).compose(map);
       *operand = transposeOp.getVector();
       changed = true;
@@ -279,13 +278,13 @@ struct CombineContractResultTranspose final
 
     // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
     // To index into A in contract, we need revert(f)(g(C)) -> A.
-    auto accTMap = AffineMap::getPermutationMap(
-        extractVector<unsigned>(accTOp.getTransp()), context);
+    auto accTMap =
+        AffineMap::getPermutationMap(accTOp.getPermutation(), context);
 
     // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
     // To index into E in contract, we need h(g(C)) -> E.
-    auto resTMap = AffineMap::getPermutationMap(
-        extractVector<unsigned>(resTOp.getTransp()), context);
+    auto resTMap =
+        AffineMap::getPermutationMap(resTOp.getPermutation(), context);
     auto combinedResMap = resTMap.compose(contractMap);
 
     // The accumulator and result share the same indexing map. So they should be
@@ -490,7 +489,7 @@ struct ReorderElementwiseOpsOnTranspose final
 
     // Make sure all operands are transpose/constant ops and collect their
     // transposition maps.
-    SmallVector<ArrayAttr> transposeMaps;
+    SmallVector<ArrayRef<int64_t>> transposeMaps;
     transposeMaps.reserve(op->getNumOperands());
     // Record the initial type before transposition. We'll use its shape later.
     // Any type will do here as we will check all transpose maps are the same.
@@ -498,7 +497,7 @@ struct ReorderElementwiseOpsOnTranspose final
     for (Value operand : op->getOperands()) {
       auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
       if (transposeOp) {
-        transposeMaps.push_back(transposeOp.getTransp());
+        transposeMaps.push_back(transposeOp.getPermutation());
         srcType = transposeOp.getSourceVectorType();
       } else if (!matchPattern(operand, m_Constant())) {
         return failure();
@@ -517,7 +516,7 @@ struct ReorderElementwiseOpsOnTranspose final
 
     // If there are constant operands, we need to insert inverse transposes for
     // them. Calculate the inverse order first.
-    auto order = extractVector<unsigned>(transposeMaps.front());
+    auto order = transposeMaps.front();
     SmallVector<int64_t> invOrder(order.size());
     for (int i = 0, e = order.size(); i < e; ++i)
       invOrder[order[i]] = i;
@@ -532,8 +531,7 @@ struct ReorderElementwiseOpsOnTranspose final
             srcType.getShape(),
             cast<VectorType>(operand.getType()).getElementType());
         srcValues.push_back(rewriter.create<vector::TransposeOp>(
-            operand.getLoc(), vectorType, operand,
-            rewriter.getI64ArrayAttr(invOrder)));
+            operand.getLoc(), vectorType, operand, invOrder));
       }
     }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 4cfac7de29ee76f..78b041255443c30 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -537,8 +537,7 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
     // Prepare the result vector;
     Value result = rewriter.create<arith::ConstantOp>(
         loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
-    SmallVector<int64_t> permutation;
-    transposeOp.getTransp(permutation);
+    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
 
     // Unroll the computation.
     for (SmallVector<int64_t> elementOffsets :
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 467a521f9eada96..48cd67ad86c63fb 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -87,14 +87,11 @@ mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
   if (srcGtOneDims.size() != 2)
     return failure();
 
-  SmallVector<int64_t> transp;
-  for (auto attr : op.getTransp())
-    transp.push_back(cast<IntegerAttr>(attr).getInt());
-
   // Check whether the two source vector dimensions that are greater than one
   // must be transposed with each other so that we can apply one of the 2-D
   // transpose pattens. Otherwise, these patterns are not applicable.
-  if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
+  if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1],
+                                  op.getPermutation()))
     return failure();
 
   return std::pair<int, int>(srcGtOneDims[0], srcGtOneDims[1]);
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 93a8d048e0a61d5..80a26a595edee0a 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -236,6 +236,12 @@ AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
   assert(permutationMap.isPermutation() && "Invalid permutation vector");
   return permutationMap;
 }
+AffineMap AffineMap::getPermutationMap(ArrayRef<int64_t> permutation,
+                                       MLIRContext *context) {
+  SmallVector<unsigned> perm = llvm::map_to_vector(
+      permutation, [](int64_t i) { return static_cast<unsigned>(i); });
+  return AffineMap::getPermutationMap(perm, context);
+}
 
 AffineMap AffineMap::getMultiDimMapWithTargets(unsigned numDims,
                                                ArrayRef<unsigned> targets,

>From 5b48e0b0f5eba3f2b0029617ad7313aa10a0c266 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 17 Nov 2023 16:29:19 +0900
Subject: [PATCH 2/3] [mlir][vector] Add extract(transpose(broadcast(x)))
 canonicalization

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  9 ++-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 73 ++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir    | 14 ++++
 3 files changed, 93 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 1397d4caf1d9d61..49860cadcd12c26 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -386,10 +386,15 @@ def Vector_BroadcastOp :
       return ::llvm::cast<VectorType>(getVector().getType());
     }
 
-    /// Return the dimensions of the result vector that were formerly ones in the
-    /// source tensor and thus correspond to "dim-1" broadcasting.
+    /// Return the dimensions of the result vector that were formerly ones in
+    /// the source vector and thus correspond to "dim-1" broadcasting.
     llvm::SetVector<int64_t> computeBroadcastedUnitDims();
 
+    /// Return the dimensions of the result vector that were newly added to the
+    /// source vector via rank extension. These are all the dimensions that were
+    /// not "dim-1" broadcasted.
+    llvm::SetVector<int64_t> computeRankExtendedDims();
+
     /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
     /// `broadcastedDims` dimensions in the dstShape are broadcasted.
     /// This requires (and asserts) that the broadcast is free of dim-1
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c7b74701fdbc8f2..957143d6c13e9e4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1897,6 +1897,62 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
   }
 };
 
+/// Canonicalize extract(transpose(broadcast(x))) constructs, where the broadcast
+/// adds a new dimension and the extraction removes it again.
+class ExtractOpTransposedBroadcastDim final
+    : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Skip vector.extract ops that do not remove any dimensions.
+    if (extractOp.getNumIndices() == 0)
+      return failure();
+    // Look for extract(transpose(broadcast(x))) pattern.
+    auto transposeOp =
+        extractOp.getVector().getDefiningOp<vector::TransposeOp>();
+    if (!transposeOp || transposeOp.getPermutation().empty())
+      return failure();
+    auto broadcastOp =
+        transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!broadcastOp)
+      return failure();
+    // Check if the first dimension that is being removed by the vector.extract
+    // was added by the vector.broadcast.
+    int64_t removedDim = transposeOp.getPermutation()[0];
+    llvm::SetVector<int64_t> rankExtendedDims =
+        broadcastOp.computeRankExtendedDims();
+    if (!rankExtendedDims.contains(removedDim))
+      return failure();
+
+    // 1. Create new vector.broadcast without the removed dimension.
+    SmallVector<int64_t> newBroadcastShape(
+        broadcastOp.getResultVectorType().getShape());
+    newBroadcastShape.erase(newBroadcastShape.begin() + removedDim);
+    auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+        broadcastOp.getLoc(),
+        VectorType::get(newBroadcastShape,
+                        broadcastOp.getResultVectorType().getElementType()),
+        broadcastOp.getSource());
+
+    // 2. Create new vector.transpose.
+    SmallVector<int64_t> newPermutation;
+    for (int64_t dim : transposeOp.getPermutation().drop_front())
+      newPermutation.push_back(dim < transposeOp.getPermutation()[0] ? dim
+                                                                     : dim - 1);
+    auto newTranspose = rewriter.create<vector::TransposeOp>(
+        transposeOp.getLoc(), newBroadcast, newPermutation);
+
+    // 3. Create new vector.extract without the outermost dimension.
+    SmallVector<OpFoldResult> mixedPositions = extractOp.getMixedPosition();
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+        extractOp, newTranspose, ArrayRef(mixedPositions).drop_front());
+
+    return success();
+  }
+};
+
 // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
 class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
 public:
@@ -2062,7 +2118,8 @@ LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
   results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
-              ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+              ExtractOpFromBroadcast, ExtractOpTransposedBroadcastDim,
+              ExtractOpFromCreateMask>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
 }
 
@@ -2112,6 +2169,20 @@ llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
                                       getResultVectorType().getShape());
 }
 
+llvm::SetVector<int64_t> BroadcastOp::computeRankExtendedDims() {
+  llvm::SetVector<int64_t> broadcastedUnitDims = computeBroadcastedUnitDims();
+  llvm::SetVector<int64_t> result;
+  auto vecSrcType = dyn_cast<VectorType>(getSourceType());
+  int64_t rankDiff =
+      vecSrcType ? getResultVectorType().getRank() - vecSrcType.getRank()
+                 : getResultVectorType().getRank();
+  for (int64_t i = 0; i < rankDiff; ++i) {
+    if (!broadcastedUnitDims.contains(i))
+      result.insert(i);
+  }
+  return result;
+}
+
 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
 /// `broadcastedDims` dimensions in the dstShape are broadcasted.
 /// This requires (and asserts) that the broadcast is free of dim-1
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d341..a6b4f7f2717da81 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2524,3 +2524,17 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
       tensor<4x4x4xf32>, vector<1x100x4x5xf32>
   return %r : vector<1x100x4x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @extract_of_transposed_broadcast_dim(
+//  CHECK-SAME:     %[[arg0:.*]]: vector<4x1xf32>
+//       CHECK:   %[[bc:.*]] = vector.broadcast %[[arg0]] : vector<4x1xf32> to vector<100x5x4x1xf32>
+//       CHECK:   %[[tp:.*]] = vector.transpose %[[bc]], [3, 0, 2, 1] : vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
+//       CHECK:   return %[[tp]]
+func.func @extract_of_transposed_broadcast_dim(%arg0: vector<4x1xf32>) -> vector<1x100x4x5xf32> {
+  %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<100x5x1x4x1xf32>
+  %1 = vector.transpose %0, [2, 4, 0, 3, 1] : vector<100x5x1x4x1xf32> to vector<1x1x100x4x5xf32>
+  %2 = vector.extract %1[0] : vector<1x100x4x5xf32> from vector<1x1x100x4x5xf32>
+  return %2 : vector<1x100x4x5xf32>
+}

>From 1d58592e08d61d6875023088c2e4b006a34a895f Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 17 Nov 2023 16:31:58 +0900
Subject: [PATCH 3/3] [mlir][vector] Better `transfer_read(transfer_write)`
 canonicalization

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 214 +++++++++++++++------
 mlir/test/Dialect/Vector/canonicalize.mlir |  10 +-
 2 files changed, 164 insertions(+), 60 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 957143d6c13e9e4..cf7c3c6c1a395ac 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4077,37 +4077,43 @@ void TransferReadOp::getEffects(
 }
 
 namespace {
-/// Store to load forwarding for transfer operations with permuation maps.
-/// Even if the permutation maps are different we can still propagate the store
-/// into the load if the size of the dimensions read and written match. Then we
-/// can replace the transfer_read + transfer_write by vector.broadcast and
-/// vector.transpose.
-/// Example:
-/// ```
-/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
-///  {in_bounds = [true, true],
-///   permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
-///   vector<4x1xf32>, tensor<4x4x4xf32>
-///  %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
-///   {in_bounds = [true, true, true, true],
-///   permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
-///   tensor<4x4x4xf32>, vector<1x100x4x5xf32>
-/// ```
-/// To:
+/// Store to load forwarding for transfer operations with permutation maps.
+/// Even if the permutation maps and/or the rank of the read/written vectors are
+/// different, we can still propagate the store into the load if the accessed
+/// chunk of the shaped value matches.
+///
+/// The vector.transfer_read op is replaced by 3 ops:
+/// 1. A broadcast of the written vector with all broadcast dims of the reading
+///    op and unit dims for all shaped value dimensions that are not transfer
+///    dimensions of the writing op.
+/// 2. A transposition of the broadcasted value to account for differences
+///    in the permutation maps of the reading/writing op.
+/// 3. An extraction that removes shaped value dimensions that are not transfer
+///    dimensions of the reading op.
+///
+/// Running example:
 /// ```
-/// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
-/// %r = vector.transpose %0, [3, 0, 2, 1] :
-///   vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
+/// %0 = vector.transfer_write %vec to %s[%a, %b, %c, %d, %e, %f]
+///     {permutation_map = affine_map<(d0, d1, d2, d3, d4, d5)
+///                                   -> (d2, d1, d4, d5)>}
+///     : vector<5x6x7x8xf32>, tensor<?x?x?x?x?x?xf32>
+/// %1 = vector.transfer_read %0[%a, %b, %c, %d, %e, %f]
+///     {permutation_map = affine_map<(d0, d1, d2, d3, d4, d5)
+///                                   -> (d1, d2, 0, d4, 0, d5, d0)>}
+///     : tensor<?x?x?x?x?x?xf32>, vector<6x5x100x7x200x8x1xf32>
 /// ```
-struct TransferReadAfterWriteToBroadcast
-    : public OpRewritePattern<TransferReadOp> {
+struct TransferReadAfterWrite : public OpRewritePattern<TransferReadOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
+    Location loc = readOp.getLoc();
     if (readOp.hasOutOfBoundsDim() ||
         !llvm::isa<RankedTensorType>(readOp.getShapedType()))
       return failure();
+    if (readOp.getShapedType().getElementType() !=
+        readOp.getVectorType().getElementType())
+      return failure();
     auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
     if (!defWrite)
       return failure();
@@ -4116,42 +4122,140 @@ struct TransferReadAfterWriteToBroadcast
     if (readOp.getTransferChunkAccessed() !=
         defWrite.getTransferChunkAccessed())
       return failure();
-    // TODO: Support cases where a dim is explicitly written but implicitly
-    // read (i.e., a unit dim that is rank reduced).
-    if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
-        getUnusedDimsBitVector({defWrite.getPermutationMap()}))
-      return failure();
-    if (readOp.getIndices() != defWrite.getIndices() ||
-        readOp.getMask() != defWrite.getMask())
+    if (readOp.getIndices() != defWrite.getIndices())
       return failure();
-    Value vec = defWrite.getVector();
-    // TODO: loop through the chain of transfer_write if we can prove that they
-    // don't overlap with the transfer_read. This requires improving
-    // `isDisjointTransferIndices` helper.
-    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
-    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
-    AffineMap map = readMap.compose(writeMap);
-    if (map.getNumResults() == 0)
+    Type elementType = readOp.getVectorType().getElementType();
+    if (elementType != defWrite.getVectorType().getElementType())
       return failure();
-    // Calculate the permutation to apply to go from the vector stored to the
-    // vector read.
-    SmallVector<unsigned> permutation;
-    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
+    if (defWrite.getShapedType().getElementType() !=
+        defWrite.getVectorType().getElementType())
       return failure();
 
-    Location loc = readOp.getLoc();
-    // Calculate the broadcast shape by applying the reverse permutation to the
-    // final shape we want.
-    ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
-    SmallVector<int64_t> broadcastShape(destShape.size());
-    for (const auto &pos : llvm::enumerate(permutation))
-      broadcastShape[pos.value()] = destShape[pos.index()];
-    VectorType broadcastedType = VectorType::get(
-        broadcastShape, defWrite.getVectorType().getElementType());
-    vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
-    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
-    rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
-                                                     transposePerm);
+    // 1. Add rank-reduced unit dimensions and broadcast dimension to input
+    //    vector %vec. Broadcast dimensions are added at the beginning, followed
+    //    by rank-reduced unit dims, followed by the dimensions of %vec.
+    //
+    // %bc = vector.broadcast %vec
+    //     : vector<5x6x7x8xf32> to vector<100x200x1x1x5x6x7x8xf32>
+    //                                      |   |   \|
+    //                               broadcast dims  |
+    //                                               |
+    //                          rank-reduced dims (corresponding to %a and %d)
+
+    // Gather broadcast dimensions of the transfer_read.
+    SmallVector<int64_t> broadcastedShape;
+    int64_t numBroadcastDims = 0;
+    for (int64_t i = 0, e = readOp.getTransferRank(); i < e; ++i) {
+      if (readOp.isBroadcastDim(i)) {
+        broadcastedShape.push_back(readOp.getVectorType().getDimSize(i));
+        ++numBroadcastDims;
+      }
+    }
+    // Append unit dims for rank-reduced (unused) dimensions in the
+    // transfer_write.
+    // Note: `getLeadingShapedRank` is a misnomer: the dimensions that do not
+    // participate in the transfer are not necessarily leading dimensions.
+    broadcastedShape.append(defWrite.getLeadingShapedRank(), 1);
+    // Append input vector (%vec) shape.
+    llvm::append_range(broadcastedShape, defWrite.getVectorType().getShape());
+    // Emit vector.broadcast op.
+    Value broadcasted = rewriter.create<vector::BroadcastOp>(
+        loc, VectorType::get(broadcastedShape, elementType),
+        defWrite.getVector());
+
+    // 2. Transpose the broadcasted vector. Dimensions that are not needed must
+    //    be placed at the beginning (because vector.extract can remove only
+    //    leading dimensions).
+
+    // Build a mapping (`shapedDimToVecDim`) from shaped value dims to dims of
+    // the broadcasted vector. This is essentially an inverted version of the
+    // transfer_write permutation map that takes into account the newly added
+    // unit dims.
+    //                                                     %b    %f
+    //                                                       \    |
+    // Example: broadcasted vector type: vector<100x200x1x1x5x6x7x8xf32>
+    //                                                 /  |  \   \
+    //                                                /  %d   |   %e
+    //                                              %a        %c
+    //          mapping = [2, 5, 4, 3, 6, 7]
+
+    // Initialize the mapping with -1.
+    SmallVector<int64_t> shapedDimToVecDim(defWrite.getShapedType().getRank(),
+                                           -1);
+    // Fill in the dimensions from the inverted transfer_write permutation map.
+    int64_t numUnitDims = defWrite.getLeadingShapedRank();
+    for (const auto &it :
+         llvm::enumerate(defWrite.getPermutationMap().getResults())) {
+      shapedDimToVecDim[cast<AffineDimExpr>(it.value()).getPosition()] =
+          it.index() + numUnitDims + numBroadcastDims;
+    }
+    // Fill in missing unused dims (of the transfer_write) with the broadcasted
+    // unit dims (which are placed right after the broadcast dims).
+    int64_t nextUnitDim = numBroadcastDims;
+    for (int64_t i = 0, e = shapedDimToVecDim.size(); i < e; ++i) {
+      if (shapedDimToVecDim[i] == -1)
+        shapedDimToVecDim[i] = nextUnitDim++;
+    }
+    assert(nextUnitDim == numBroadcastDims + numUnitDims &&
+           "unexpected number of unit dims");
+
+    // Compute permutation. All dims that are not needed by the transfer_read
+    // are placed at the beginning.
+    SmallVector<int64_t> permutation(broadcastedShape.size(), -1);
+    // Helper data structure to keep track of dims that were not used yet.
+    SmallVector<int64_t> remainingDims =
+        llvm::to_vector(llvm::seq<int64_t>(0, broadcastedShape.size()));
+    int64_t numUnneededDims =
+        broadcastedShape.size() - readOp.getVectorType().getRank();
+    int64_t nextBroadcastDim = 0;
+    for (int64_t i = 0, e = readOp.getVectorType().getRank(); i < e; ++i) {
+      if (readOp.isBroadcastDim(i)) {
+        // This transfer_read result dim is a broadcast.
+        permutation[numUnneededDims + i] = nextBroadcastDim;
+        auto it = llvm::find(remainingDims, nextBroadcastDim);
+        assert(it != remainingDims.end() && "could not find broadcast dim");
+        remainingDims.erase(it);
+        nextBroadcastDim++;
+        continue;
+      }
+      // This transfer_read result dim is a dimension of the shaped value. Look
+      // up its position in the broadcasted vector in the mapping.
+      int64_t shapedValueDim =
+          cast<AffineDimExpr>(readOp.getPermutationMap().getResult(i))
+              .getPosition();
+      permutation[numUnneededDims + i] = shapedDimToVecDim[shapedValueDim];
+      auto it = llvm::find(remainingDims, shapedDimToVecDim[shapedValueDim]);
+      assert(it != remainingDims.end() && "could not find regular dim");
+      remainingDims.erase(it);
+    }
+
+    // Fill up the dimensions at the beginning with all remaining dims.
+    assert(remainingDims.size() == numUnneededDims &&
+           "unexpected number of remaining dims");
+    for (int64_t i = 0; i < numUnneededDims; ++i) {
+      // All unneeded dims must be unit dimensions. Otherwise, the two transfer
+      // ops would be accessing different chunks.
+      assert(broadcastedShape[remainingDims[i]] == 1 && "expected unit dim");
+      permutation[i] = remainingDims[i];
+    }
+
+    // Build vector.transpose op.
+    //
+    //            unneeded dim (%d)    broadcast dims
+    //                       \          /   \
+    // %tp = vector.transpose %bc, [3, 5, 4, 0, 6, 1, 7, 2]
+    //     : vector<100x200x1x1x5x6x7x8xf32> to vector<1x6x5x100x7x200x8x1xf32>
+    Value transposed = rewriter.create<vector::TransposeOp>(
+        defWrite.getLoc(), broadcasted, permutation);
+
+    // 3. Remove unneeded dims.
+    //
+    // %1 = vector.extract %tp[0]
+    //     : vector<6x5x100x7x200x8x1xf32> from vector<1x6x5x100x7x200x8x1xf32>
+    SmallVector<int64_t> extractPositions(numUnneededDims, 0);
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(readOp, transposed,
+                                                   extractPositions);
+
     return success();
   }
 };
@@ -4159,7 +4263,7 @@ struct TransferReadAfterWriteToBroadcast
 
 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results.add<TransferReadAfterWriteToBroadcast>(context);
+  results.add<TransferReadAfterWrite>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a6b4f7f2717da81..308e0602ee46295 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2503,12 +2503,12 @@ func.func @fold_shape_cast_with_constant_mask() -> vector<4xi1>{
 
 // -----
 
-// TODO: This IR could be canonicalized but the canonicalization pattern is not
-// smart enough. For now, just make sure that we do not crash.
-
 // CHECK-LABEL: func.func @load_store_forwarding_rank_mismatch(
-//       CHECK:   vector.transfer_write
-//       CHECK:   vector.transfer_read
+//  CHECK-SAME:     %[[v0:.*]]: vector<4x1x1xf32>
+//       CHECK:   %[[bc:.*]] = vector.broadcast %[[v0]] : vector<4x1x1xf32> to vector<100x5x4x1x1xf32>
+//       CHECK:   %[[tp:.*]] = vector.transpose %[[bc]], [4, 3, 0, 2, 1] : vector<100x5x4x1x1xf32> to vector<1x1x100x4x5xf32>
+//       CHECK:   %[[extract:.*]] = vector.extract %[[tp]][0] : vector<1x100x4x5xf32> from vector<1x1x100x4x5xf32>
+//       CHECK:   return %[[extract]]
 func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: tensor<4x4x4xf32>) -> (vector<1x100x4x5xf32>) {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32


