[Mlir-commits] [mlir] [mlir][vector] Modernize `vector.transpose` op (PR #72594)

Matthias Springer llvmlistbot at llvm.org
Thu Nov 16 18:12:29 PST 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/72594

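* Declare arguments/results with `let` statements.
* Rename `transp` to `permutation`.
* Change type of `transp` from `I64ArrayAttr` to `DenseI64ArrayAttr` (provides direct access to `ArrayRef<int64_t>` instead of `ArrayAttr`).
* Remove the `getTransp(SmallVectorImpl<int64_t> &)` helper; callers now use `getPermutation()` directly.
* Add an `AffineMap::getPermutationMap` overload that takes an `ArrayRef<int64_t>` permutation.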

>From a4b57fdb77f637f82bb556d74f146138bbf1fab8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 17 Nov 2023 11:10:50 +0900
Subject: [PATCH] [mlir][vector] Modernize `vector.transpose` op

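* Declare arguments/results with `let` statements.
* Rename `transp` to `permutation`.
* Change type of `transp` from `I64ArrayAttr` to `DenseI64ArrayAttr` (provides direct access to `ArrayRef<int64_t>` instead of `ArrayAttr`).
* Remove the `getTransp(SmallVectorImpl<int64_t> &)` helper; callers now use `getPermutation()` directly.
* Add an `AffineMap::getPermutationMap` overload that takes an `ArrayRef<int64_t>` permutation.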
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 17 +++--
 mlir/include/mlir/IR/AffineMap.h              |  2 +
 .../VectorToArmSME/VectorToArmSME.cpp         |  7 +-
 .../Conversion/VectorToGPU/VectorToGPU.cpp    |  7 +-
 .../Dialect/Arith/Transforms/IntNarrowing.cpp |  2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 64 +++++++------------
 .../Transforms/LowerVectorTranspose.cpp       |  4 +-
 .../Vector/Transforms/VectorTransforms.cpp    | 20 +++---
 .../Vector/Transforms/VectorUnroll.cpp        |  3 +-
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp |  7 +-
 mlir/lib/IR/AffineMap.cpp                     |  6 ++
 11 files changed, 59 insertions(+), 80 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 02e80a6446dfb24..1397d4caf1d9d61 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2436,14 +2436,13 @@ def Vector_TransposeOp :
   Vector_Op<"transpose", [Pure,
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
     PredOpTrait<"operand and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
-    Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$transp)>,
-    Results<(outs AnyVectorOfAnyRank:$result)> {
+                 TCresVTEtIsSameAsOpBase<0, 0>>]> {
   let summary = "vector transpose operation";
   let description = [{
     Takes a n-D vector and returns the transposed n-D vector defined by
     the permutation of ranks in the n-sized integer array attribute (in case
     of 0-D vectors the array attribute must be empty).
+
     In the operation
 
     ```mlir
@@ -2452,7 +2451,7 @@ def Vector_TransposeOp :
       to vector<d_trans[0] x .. x d_trans[n-1] x f32>
     ```
 
-    the transp array [i_1, .., i_n] must be a permutation of [0, .., n-1].
+    the `permutation` array [i_1, .., i_n] must be a permutation of [0, .., n-1].
 
     Example:
 
@@ -2464,8 +2463,13 @@ def Vector_TransposeOp :
                           [c, f] ]
     ```
   }];
+
+  let arguments = (ins AnyVectorOfAnyRank:$vector,
+                       DenseI64ArrayAttr:$permutation);
+  let results = (outs AnyVectorOfAnyRank:$result);
+
   let builders = [
-    OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$transp)>
+    OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$permutation)>
   ];
   let extraClassDeclaration = [{
     VectorType getSourceVectorType() {
@@ -2474,10 +2478,9 @@ def Vector_TransposeOp :
     VectorType getResultVectorType() {
       return ::llvm::cast<VectorType>(getResult().getType());
     }
-    void getTransp(SmallVectorImpl<int64_t> &results);
   }];
   let assemblyFormat = [{
-    $vector `,` $transp attr-dict `:` type($vector) `to` type($result)
+    $vector `,` $permutation attr-dict `:` type($vector) `to` type($result)
   }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 713aef767edf669..981f3d392cbc98c 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -103,6 +103,8 @@ class AffineMap {
   /// (i.e. `[1,1,2]` is an invalid permutation).
   static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
                                      MLIRContext *context);
+  static AffineMap getPermutationMap(ArrayRef<int64_t> permutation,
+                                     MLIRContext *context);
 
   /// Returns an affine map with `numDims` input dimensions and results
   /// specified by `targets`.
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 953a465c18de69f..01c782676068d9a 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -436,12 +436,9 @@ struct TransposeOpToArmSMELowering
     if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
       return failure();
 
-    SmallVector<int64_t> transp;
-    for (auto attr : transposeOp.getTransp())
-      transp.push_back(cast<IntegerAttr>(attr).getInt());
-
     // Bail unless this is a true 2-D matrix transpose.
-    if (transp[0] != 1 || transp[1] != 0)
+    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+    if (permutation[0] != 1 || permutation[1] != 0)
       return failure();
 
     OpBuilder::InsertionGuard g(rewriter);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 1126c2c20758c7a..429d1137b6f3781 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -473,13 +473,8 @@ struct CombineTransferReadOpTranspose final
     if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
       return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
 
-    SmallVector<int64_t, 2> perm;
-    op.getTransp(perm);
-    SmallVector<unsigned, 2> permU;
-    for (int64_t o : perm)
-      permU.push_back(unsigned(o));
     AffineMap permutationMap =
-        AffineMap::getPermutationMap(permU, op.getContext());
+        AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
     AffineMap newMap =
         permutationMap.compose(transferReadOp.getPermutationMap());
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 1084fbc890053b9..79fabd6ed2e99a2 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -712,7 +712,7 @@ struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
     VectorType newTy =
         origTy.cloneWith(origTy.getShape(), ext->getInElementType());
     Value newTranspose = rewriter.create<vector::TransposeOp>(
-        op.getLoc(), newTy, ext->getIn(), op.getTransp());
+        op.getLoc(), newTy, ext->getIn(), op.getPermutation());
     ext->recreateAndReplace(rewriter, op, newTranspose);
     return success();
   }
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 95f49fa32bc0ae2..c7b74701fdbc8f2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1456,9 +1456,8 @@ LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
 
   if (!nextTransposeOp)
     return failure();
-  auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp());
-  AffineMap m = inversePermutation(
-      AffineMap::getPermutationMap(permutation, extractOp.getContext()));
+  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
+      nextTransposeOp.getPermutation(), extractOp.getContext()));
   extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
   return success();
 }
@@ -5376,20 +5375,20 @@ LogicalResult TypeCastOp::verify() {
 //===----------------------------------------------------------------------===//
 
 void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
-                                Value vector, ArrayRef<int64_t> transp) {
+                                Value vector, ArrayRef<int64_t> permutation) {
   VectorType vt = llvm::cast<VectorType>(vector.getType());
   SmallVector<int64_t, 4> transposedShape(vt.getRank());
   SmallVector<bool, 4> transposedScalableDims(vt.getRank());
-  for (unsigned i = 0; i < transp.size(); ++i) {
-    transposedShape[i] = vt.getShape()[transp[i]];
-    transposedScalableDims[i] = vt.getScalableDims()[transp[i]];
+  for (auto [i, dim] : llvm::enumerate(permutation)) {
+    transposedShape[i] = vt.getShape()[dim];
+    transposedScalableDims[i] = vt.getScalableDims()[dim];
   }
 
   result.addOperands(vector);
   result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
                                   transposedScalableDims));
-  result.addAttribute(TransposeOp::getTranspAttrName(result.name),
-                      builder.getI64ArrayAttr(transp));
+  result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
+                      builder.getDenseI64ArrayAttr(permutation));
 }
 
 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
@@ -5401,13 +5400,12 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
 
   // Eliminate identity transpose ops. This happens when the dimensions of the
   // input vector remain in their original order after the transpose operation.
-  SmallVector<int64_t, 4> transp;
-  getTransp(transp);
+  ArrayRef<int64_t> perm = getPermutation();
 
   // Check if the permutation of the dimensions contains sequential values:
   // {0, 1, 2, ...}.
-  for (int64_t i = 0, e = transp.size(); i < e; i++) {
-    if (transp[i] != i)
+  for (int64_t i = 0, e = perm.size(); i < e; i++) {
+    if (perm[i] != i)
       return {};
   }
 
@@ -5421,20 +5419,19 @@ LogicalResult vector::TransposeOp::verify() {
   if (vectorType.getRank() != rank)
     return emitOpError("vector result rank mismatch: ") << rank;
   // Verify transposition array.
-  auto transpAttr = getTransp().getValue();
-  int64_t size = transpAttr.size();
+  ArrayRef<int64_t> perm = getPermutation();
+  int64_t size = perm.size();
   if (rank != size)
     return emitOpError("transposition length mismatch: ") << size;
   SmallVector<bool, 8> seen(rank, false);
-  for (const auto &ta : llvm::enumerate(transpAttr)) {
-    int64_t i = llvm::cast<IntegerAttr>(ta.value()).getInt();
-    if (i < 0 || i >= rank)
-      return emitOpError("transposition index out of range: ") << i;
-    if (seen[i])
-      return emitOpError("duplicate position index: ") << i;
-    seen[i] = true;
-    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
-      return emitOpError("dimension size mismatch at: ") << i;
+  for (auto [resDim, vecDim] : llvm::enumerate(perm)) {
+    if (vecDim < 0 || vecDim >= rank)
+      return emitOpError("transposition index out of range: ") << vecDim;
+    if (seen[vecDim])
+      return emitOpError("duplicate position index: ") << vecDim;
+    seen[vecDim] = true;
+    if (resultType.getDimSize(resDim) != vectorType.getDimSize(vecDim))
+      return emitOpError("dimension size mismatch at: ") << vecDim;
   }
   return success();
 }
@@ -5452,13 +5449,6 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
 
   LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                 PatternRewriter &rewriter) const override {
-    // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
-    auto getPermutation = [](vector::TransposeOp transpose) {
-      SmallVector<int64_t, 4> permutation;
-      transpose.getTransp(permutation);
-      return permutation;
-    };
-
     // Composes two permutations: result[i] = permutation1[permutation2[i]].
     auto composePermutations = [](ArrayRef<int64_t> permutation1,
                                   ArrayRef<int64_t> permutation2) {
@@ -5475,12 +5465,11 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
       return failure();
 
     SmallVector<int64_t, 4> permutation = composePermutations(
-        getPermutation(parentTransposeOp), getPermutation(transposeOp));
+        parentTransposeOp.getPermutation(), transposeOp.getPermutation());
     // Replace 'transposeOp' with a new transpose operation.
     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
         transposeOp, transposeOp.getResult().getType(),
-        parentTransposeOp.getVector(),
-        vector::getVectorSubscriptAttr(rewriter, permutation));
+        parentTransposeOp.getVector(), permutation);
     return success();
   }
 };
@@ -5539,8 +5528,7 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
 
     // Get the transpose permutation and apply it to the vector.create_mask or
     // vector.constant_mask operands.
-    SmallVector<int64_t> permutation;
-    transpOp.getTransp(permutation);
+    ArrayRef<int64_t> permutation = transpOp.getPermutation();
 
     if (createMaskOp) {
       auto maskOperands = createMaskOp.getOperands();
@@ -5572,10 +5560,6 @@ void vector::TransposeOp::getCanonicalizationPatterns(
               TransposeFolder, FoldTransposeSplat>(context);
 }
 
-void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
-  populateFromInt64AttrArray(getTransp(), results);
-}
-
 //===----------------------------------------------------------------------===//
 // ConstantMaskOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index dee786007c80630..97f6caca1b25ccc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -327,9 +327,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     VectorType resType = op.getResultVectorType();
 
     // Set up convenience transposition table.
-    SmallVector<int64_t> transp;
-    for (auto attr : op.getTransp())
-      transp.push_back(cast<IntegerAttr>(attr).getInt());
+    ArrayRef<int64_t> transp = op.getPermutation();
 
     if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
         succeeded(isTranspose2DSlice(op)))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 713f9cb72c82cec..a20c8aeeb6f7108 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -212,8 +212,7 @@ struct CombineContractABTranspose final
       if (!transposeOp)
         continue;
       AffineMap permutationMap = AffineMap::getPermutationMap(
-          extractVector<unsigned>(transposeOp.getTransp()),
-          contractOp.getContext());
+          transposeOp.getPermutation(), contractOp.getContext());
       map = inversePermutation(permutationMap).compose(map);
       *operand = transposeOp.getVector();
       changed = true;
@@ -279,13 +278,13 @@ struct CombineContractResultTranspose final
 
     // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
     // To index into A in contract, we need revert(f)(g(C)) -> A.
-    auto accTMap = AffineMap::getPermutationMap(
-        extractVector<unsigned>(accTOp.getTransp()), context);
+    auto accTMap =
+        AffineMap::getPermutationMap(accTOp.getPermutation(), context);
 
     // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
     // To index into E in contract, we need h(g(C)) -> E.
-    auto resTMap = AffineMap::getPermutationMap(
-        extractVector<unsigned>(resTOp.getTransp()), context);
+    auto resTMap =
+        AffineMap::getPermutationMap(resTOp.getPermutation(), context);
     auto combinedResMap = resTMap.compose(contractMap);
 
     // The accumulator and result share the same indexing map. So they should be
@@ -490,7 +489,7 @@ struct ReorderElementwiseOpsOnTranspose final
 
     // Make sure all operands are transpose/constant ops and collect their
     // transposition maps.
-    SmallVector<ArrayAttr> transposeMaps;
+    SmallVector<ArrayRef<int64_t>> transposeMaps;
     transposeMaps.reserve(op->getNumOperands());
     // Record the initial type before transposition. We'll use its shape later.
     // Any type will do here as we will check all transpose maps are the same.
@@ -498,7 +497,7 @@ struct ReorderElementwiseOpsOnTranspose final
     for (Value operand : op->getOperands()) {
       auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
       if (transposeOp) {
-        transposeMaps.push_back(transposeOp.getTransp());
+        transposeMaps.push_back(transposeOp.getPermutation());
         srcType = transposeOp.getSourceVectorType();
       } else if (!matchPattern(operand, m_Constant())) {
         return failure();
@@ -517,7 +516,7 @@ struct ReorderElementwiseOpsOnTranspose final
 
     // If there are constant operands, we need to insert inverse transposes for
     // them. Calculate the inverse order first.
-    auto order = extractVector<unsigned>(transposeMaps.front());
+    auto order = transposeMaps.front();
     SmallVector<int64_t> invOrder(order.size());
     for (int i = 0, e = order.size(); i < e; ++i)
       invOrder[order[i]] = i;
@@ -532,8 +531,7 @@ struct ReorderElementwiseOpsOnTranspose final
             srcType.getShape(),
             cast<VectorType>(operand.getType()).getElementType());
         srcValues.push_back(rewriter.create<vector::TransposeOp>(
-            operand.getLoc(), vectorType, operand,
-            rewriter.getI64ArrayAttr(invOrder)));
+            operand.getLoc(), vectorType, operand, invOrder));
       }
     }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 4cfac7de29ee76f..78b041255443c30 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -537,8 +537,7 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
     // Prepare the result vector;
     Value result = rewriter.create<arith::ConstantOp>(
         loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
-    SmallVector<int64_t> permutation;
-    transposeOp.getTransp(permutation);
+    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
 
     // Unroll the computation.
     for (SmallVector<int64_t> elementOffsets :
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 467a521f9eada96..48cd67ad86c63fb 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -87,14 +87,11 @@ mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
   if (srcGtOneDims.size() != 2)
     return failure();
 
-  SmallVector<int64_t> transp;
-  for (auto attr : op.getTransp())
-    transp.push_back(cast<IntegerAttr>(attr).getInt());
-
   // Check whether the two source vector dimensions that are greater than one
   // must be transposed with each other so that we can apply one of the 2-D
   // transpose pattens. Otherwise, these patterns are not applicable.
-  if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
+  if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1],
+                                  op.getPermutation()))
     return failure();
 
   return std::pair<int, int>(srcGtOneDims[0], srcGtOneDims[1]);
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 93a8d048e0a61d5..80a26a595edee0a 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -236,6 +236,12 @@ AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
   assert(permutationMap.isPermutation() && "Invalid permutation vector");
   return permutationMap;
 }
+AffineMap AffineMap::getPermutationMap(ArrayRef<int64_t> permutation,
+                                       MLIRContext *context) {
+  SmallVector<unsigned> perm = llvm::map_to_vector(
+      permutation, [](int64_t i) { return static_cast<unsigned>(i); });
+  return AffineMap::getPermutationMap(perm, context);
+}
 
 AffineMap AffineMap::getMultiDimMapWithTargets(unsigned numDims,
                                                ArrayRef<unsigned> targets,



More information about the Mlir-commits mailing list