[Mlir-commits] [mlir] f71f995 - [mlir][Vector] Modernize default lowering of vector transpose

Diego Caballero llvmlistbot at llvm.org
Thu Mar 10 14:35:20 PST 2022


Author: Diego Caballero
Date: 2022-03-10T22:33:14Z
New Revision: f71f9958b9845878909e005c67970e48b300f991

URL: https://github.com/llvm/llvm-project/commit/f71f9958b9845878909e005c67970e48b300f991
DIFF: https://github.com/llvm/llvm-project/commit/f71f9958b9845878909e005c67970e48b300f991.diff

LOG: [mlir][Vector] Modernize default lowering of vector transpose

This patch removes an old recursive implementation to lower vector.transpose to extract/insert operations
and replaces it with an iterative approach that leverages newer linearization/delinearization utilities.
The patch should be NFC (no functional change) except for the order in which the extract/insert ops are generated.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D121321

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/include/mlir/Dialect/Utils/IndexingUtils.h
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
    mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transpose-lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index e53b25cc22f2c..4f991588d1d43 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -32,19 +32,6 @@ class LinalgDependenceGraph;
 /// `[0, permutation.size())`.
 bool isPermutation(ArrayRef<int64_t> permutation);
 
-/// Apply the permutation defined by `permutation` to `inVec`.
-/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
-/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
-/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
-template <typename T, unsigned N>
-void applyPermutationToVector(SmallVector<T, N> &inVec,
-                              ArrayRef<int64_t> permutation) {
-  SmallVector<T, N> auxVec(inVec.size());
-  for (const auto &en : enumerate(permutation))
-    auxVec[en.index()] = inVec[en.value()];
-  inVec = auxVec;
-}
-
 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
 /// the type of `source`.
 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);

diff  --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 4678ce0648150..3f2dd00c696f8 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -30,6 +30,19 @@ int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
 SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> strides,
                                     int64_t linearIndex);
 
+/// Permute `inVec` in place according to `permutation` (assumed same size).
+/// Location `i` receives the element previously at `inVec[permutation[i]]`.
+/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
+/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
+template <typename T, unsigned N>
+void applyPermutationToVector(SmallVector<T, N> &inVec,
+                              ArrayRef<int64_t> permutation) {
+  SmallVector<T, N> auxVec(inVec.size());
+  for (const auto &en : enumerate(permutation))
+    auxVec[en.index()] = inVec[en.value()];
+  inVec = auxVec;
+}
+
 /// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
 SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                        unsigned dropFront = 0,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 4297a83005fe5..1d46657018b39 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/Support/LLVM.h"

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index f1cd988d1fd74..907ffa3be4b95 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AsmState.h"

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
index 452bf9d30ee4a..4ce38530fe1e9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 2e1418c529a25..4f863298ba422 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/Transforms/FoldUtils.h"

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 2b22412d6fc36..ab353326093a8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
@@ -300,16 +301,18 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
   }
 };
 
-/// Return the number of leftmost dimensions from the first rightmost transposed
-/// dimension found in 'transpose'.
-size_t getNumDimsFromFirstTransposedDim(ArrayRef<int64_t> transpose) {
+/// Append to `result` the prefix of `transpose` left after dropping the
+/// trailing identity (non-transposed) dimensions, e.g. [1,0,2,3] -> [1,0].
+void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
+                            SmallVectorImpl<int64_t> &result) {
  size_t numTransposedDims = transpose.size();
   for (size_t transpDim : llvm::reverse(transpose)) {
     if (transpDim != numTransposedDims - 1)
       break;
     numTransposedDims--;
   }
-  return numTransposedDims;
+
+  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
 }
 
 /// Progressive lowering of TransposeOp.
@@ -334,6 +337,8 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
                                 PatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
 
+    Value input = op.vector();
+    VectorType inputType = op.getVectorType();
     VectorType resType = op.getResultType();
 
     // Set up convenience transposition table.
@@ -354,7 +359,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       Type flattenedType =
           VectorType::get(resType.getNumElements(), resType.getElementType());
       auto matrix =
-          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector());
+          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
       auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
       auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
       Value trans = rewriter.create<vector::FlatTransposeOp>(
@@ -365,54 +370,40 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
 
     // Generate unrolled extract/insert ops. We do not unroll the rightmost
     // (i.e., highest-order) dimensions that are not transposed and leave them
-    // in vector form to improve performance.
-    size_t numLeftmostTransposedDims = getNumDimsFromFirstTransposedDim(transp);
-
-    // The type of the extract operation will be scalar if all the dimensions
-    // are unrolled. Otherwise, it will be a vector with the shape of the
-    // dimensions that are not transposed.
-    Type extractType =
-        numLeftmostTransposedDims == transp.size()
-            ? resType.getElementType()
-            : VectorType::Builder(resType).setShape(
-                  resType.getShape().drop_front(numLeftmostTransposedDims));
-
+    // in vector form to improve performance. To do so, those trailing
+    // dimensions are pruned from the transpose and the input shape, so each
+    // extract/insert below moves a scalar or a vector of the pruned dims.
+    SmallVector<int64_t, 4> prunedTransp;
+    pruneNonTransposedDims(transp, prunedTransp);
+    size_t numPrunedDims = transp.size() - prunedTransp.size();
+    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
+    SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
+    auto prunedInStrides = computeStrides(prunedInShape, ones);
+
+    // Visit every position of the pruned input shape via a linear index.
+    // Delinearizing it yields the extract (input) indices; permuting them with
+    // `prunedTransp` (insertIdxs[i] = extractIdxs[prunedTransp[i]]) yields the
+    // matching insert (result) indices, chained onto a zero-filled `result`.
     Value result = rewriter.create<arith::ConstantOp>(
         loc, resType, rewriter.getZeroAttr(resType));
-    SmallVector<int64_t, 4> lhs(numLeftmostTransposedDims, 0);
-    SmallVector<int64_t, 4> rhs(numLeftmostTransposedDims, 0);
-    rewriter.replaceOp(op, expandIndices(loc, resType, extractType, 0,
-                                         numLeftmostTransposedDims, transp, lhs,
-                                         rhs, op.vector(), result, rewriter));
+    int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
+
+    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
+         ++linearIdx) {
+      auto extractIdxs = delinearize(prunedInStrides, linearIdx);
+      SmallVector<int64_t, 4> insertIdxs(extractIdxs);
+      applyPermutationToVector(insertIdxs, prunedTransp);
+      Value extractOp =
+          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
+      result =
+          rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
+    }
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 
 private:
-  // Builds the indices arrays for the lhs and rhs. Generates the extract/insert
-  // operations when all the ranks go over the last dimension being transposed.
-  Value expandIndices(Location loc, VectorType resType, Type extractType,
-                      int64_t pos, int64_t numLeftmostTransposedDims,
-                      SmallVector<int64_t, 4> &transp,
-                      SmallVector<int64_t, 4> &lhs,
-                      SmallVector<int64_t, 4> &rhs, Value input, Value result,
-                      PatternRewriter &rewriter) const {
-    if (pos >= numLeftmostTransposedDims) {
-      auto ridx = rewriter.getI64ArrayAttr(rhs);
-      auto lidx = rewriter.getI64ArrayAttr(lhs);
-      Value e =
-          rewriter.create<vector::ExtractOp>(loc, extractType, input, ridx);
-      return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx);
-    }
-    for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
-      lhs[pos] = d;
-      rhs[transp[pos]] = d;
-      result = expandIndices(loc, resType, extractType, pos + 1,
-                             numLeftmostTransposedDims, transp, lhs, rhs, input,
-                             result, rewriter);
-    }
-    return result;
-  }
-
   /// Options to control the vector patterns.
   vector::VectorTransformsOptions vectorTransformOptions;
 };

diff  --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 4087eab4a5864..245e40cda8eea 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -8,14 +8,14 @@
 // ELTWISE:      %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
 // ELTWISE:      %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32>
 // ELTWISE:      %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32>
-// ELTWISE:      %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
-// ELTWISE:      %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32>
-// ELTWISE:      %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
-// ELTWISE:      %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32>
-// ELTWISE:      %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
-// ELTWISE:      %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32>
-// ELTWISE:      %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
-// ELTWISE:      %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32>
+// ELTWISE:      %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
+// ELTWISE:      %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : f32 into vector<3x2xf32>
+// ELTWISE:      %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
+// ELTWISE:      %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2, 0] : f32 into vector<3x2xf32>
+// ELTWISE:      %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
+// ELTWISE:      %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [0, 1] : f32 into vector<3x2xf32>
+// ELTWISE:      %[[T8:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
+// ELTWISE:      %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<3x2xf32>
 // ELTWISE:      %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32>
 // ELTWISE:      %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32>
 // ELTWISE:      return %[[T11]] : vector<3x2xf32>


        


More information about the Mlir-commits mailing list