[Mlir-commits] [mlir] 917d95f - [mlir][Vector] Improve default lowering of vector transpose operations

Diego Caballero llvmlistbot at llvm.org
Mon Mar 7 09:57:56 PST 2022


Author: Diego Caballero
Date: 2022-03-07T17:56:02Z
New Revision: 917d95fc8adb2d68da442daee4e2508ada8c9bdc

URL: https://github.com/llvm/llvm-project/commit/917d95fc8adb2d68da442daee4e2508ada8c9bdc
DIFF: https://github.com/llvm/llvm-project/commit/917d95fc8adb2d68da442daee4e2508ada8c9bdc.diff

LOG: [mlir][Vector] Improve default lowering of vector transpose operations

The default lowering of vector transpose operations generates a long sequence of
scalar extract/insert operations, one pair for each scalar element in the input vector.
In other words, the vector transpose is fully scalarized. However, there are transpose
patterns where one or more adjacent rightmost (highest-order) dimensions are not transposed
(for example, in the transpose pattern [1, 0, 2, 3], dimensions 2 and 3 are not transposed).
This patch improves the lowering of those cases by extracting/inserting a full n-D vector
instead of scalarizing them, where 'n' is the number of adjacent rightmost dimensions that
are not transposed. By doing so, we avoid scalarizing the code and generate a more
performant vector version, as sketched below.
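
As a rough illustration (a sketch based on the updated test expectations below;
SSA value names are illustrative), a [1, 0, 2] transpose of a vector<1x8x8xf32>
now lowers to eight extract/insert pairs moving whole vector<8xf32> rows instead
of 64 scalar pairs:

  %cst = arith.constant dense<0.000000e+00> : vector<8x1x8xf32>
  %e0 = vector.extract %arg0[0, 0] : vector<1x8x8xf32>
  %r0 = vector.insert %e0, %cst [0, 0] : vector<8xf32> into vector<8x1x8xf32>
  %e1 = vector.extract %arg0[0, 1] : vector<1x8x8xf32>
  %r1 = vector.insert %e1, %r0 [1, 0] : vector<8xf32> into vector<8x1x8xf32>
  ...
  %e7 = vector.extract %arg0[0, 7] : vector<1x8x8xf32>
  %r7 = vector.insert %e7, %r6 [7, 0] : vector<8xf32> into vector<8x1x8xf32>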

Paradoxically, this patch shouldn't improve the performance of transpose operations when
targeting LLVM: the LLVM pipeline is able to optimize away some of the extract/insert
operations, and the SLP vectorizer converts the remaining scalar operations back into
vector form. However, scalarizing a vector version of the code in MLIR and relying on the
SLP vectorizer to reconstruct the vector code is highly undesirable for several reasons.

Reviewed By: nicolasvasilache, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D120601

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transpose-lowering.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f9413a7468187..4421b55fb40d7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -300,6 +300,18 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
   }
 };
 
+/// Return the number of leftmost dimensions in 'transpose', up to and
+/// including the rightmost dimension that is actually transposed.
+size_t getNumDimsFromFirstTransposedDim(ArrayRef<int64_t> transpose) {
+  size_t numTransposedDims = transpose.size();
+  for (size_t transpDim : llvm::reverse(transpose)) {
+    if (transpDim != numTransposedDims - 1)
+      break;
+    numTransposedDims--;
+  }
+  return numTransposedDims;
+}
+
 /// Progressive lowering of TransposeOp.
 /// One:
 ///   %x = vector.transpose %y, [1, 0]
@@ -351,35 +363,51 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       return success();
     }
 
-    // Generate fully unrolled extract/insert ops.
+    // Generate unrolled extract/insert ops. We do not unroll the rightmost
+    // (i.e., highest-order) dimensions that are not transposed and leave them
+    // in vector form to improve performance.
+    size_t numLeftmostTransposedDims = getNumDimsFromFirstTransposedDim(transp);
+
+    // The type of the extract operation will be scalar if all the dimensions
+    // are unrolled. Otherwise, it will be a vector with the shape of the
+    // dimensions that are not transposed.
+    Type extractType =
+        numLeftmostTransposedDims == transp.size()
+            ? resType.getElementType()
+            : VectorType::Builder(resType).setShape(
+                  resType.getShape().drop_front(numLeftmostTransposedDims));
+
     Value result = rewriter.create<arith::ConstantOp>(
         loc, resType, rewriter.getZeroAttr(resType));
-    SmallVector<int64_t, 4> lhs(transp.size(), 0);
-    SmallVector<int64_t, 4> rhs(transp.size(), 0);
-    rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
-                                         op.vector(), result, rewriter));
+    SmallVector<int64_t, 4> lhs(numLeftmostTransposedDims, 0);
+    SmallVector<int64_t, 4> rhs(numLeftmostTransposedDims, 0);
+    rewriter.replaceOp(op, expandIndices(loc, resType, extractType, 0,
+                                         numLeftmostTransposedDims, transp, lhs,
+                                         rhs, op.vector(), result, rewriter));
     return success();
   }
 
 private:
   // Builds the indices arrays for the lhs and rhs. Generates the extract/insert
-  // operation when al ranks are exhausted.
-  Value expandIndices(Location loc, VectorType resType, int64_t pos,
+  // operations once all the dimensions up to the last transposed one are expanded.
+  Value expandIndices(Location loc, VectorType resType, Type extractType,
+                      int64_t pos, int64_t numLeftmostTransposedDims,
                       SmallVector<int64_t, 4> &transp,
                       SmallVector<int64_t, 4> &lhs,
                       SmallVector<int64_t, 4> &rhs, Value input, Value result,
                       PatternRewriter &rewriter) const {
-    if (pos >= resType.getRank()) {
+    if (pos >= numLeftmostTransposedDims) {
       auto ridx = rewriter.getI64ArrayAttr(rhs);
       auto lidx = rewriter.getI64ArrayAttr(lhs);
-      Type eltType = resType.getElementType();
-      Value e = rewriter.create<vector::ExtractOp>(loc, eltType, input, ridx);
+      Value e =
+          rewriter.create<vector::ExtractOp>(loc, extractType, input, ridx);
       return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx);
     }
     for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
       lhs[pos] = d;
       rhs[transp[pos]] = d;
-      result = expandIndices(loc, resType, pos + 1, transp, lhs, rhs, input,
+      result = expandIndices(loc, resType, extractType, pos + 1,
+                             numLeftmostTransposedDims, transp, lhs, rhs, input,
                              result, rewriter);
     }
     return result;

diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 651006b07d2b7..4087eab4a5864 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -577,8 +577,25 @@ func @transpose021_8x8x1xf32(%arg0: vector<8x8x1xf32>) -> vector<8x1x8xf32> {
 
 // -----
 
+// ELTWISE-LABEL: func @transpose102_1x8x8xf32
 // AVX2-LABEL: func @transpose102_1x8x8
 func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> {
+  //      ELTWISE: vector.extract {{.*}}[0, 0] : vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[0, 1] : vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [1, 0] : vector<8xf32> into vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[0, 2] : vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [2, 0] : vector<8xf32> into vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[0, 3] : vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [3, 0] : vector<8xf32> into vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[0, 4] : vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [4, 0] : vector<8xf32> into vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[0, 5] : vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [5, 0] : vector<8xf32> into vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[0, 6] : vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [6, 0] : vector<8xf32> into vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[0, 7] : vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [7, 0] : vector<8xf32> into vector<8x1x8xf32>
   %0 = vector.transpose %arg0, [1, 0, 2] : vector<1x8x8xf32> to vector<8x1x8xf32>
   return %0 : vector<8x1x8xf32>
 }
@@ -587,8 +604,25 @@ func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> {
 
 // -----
 
+// ELTWISE-LABEL: func @transpose102_8x1x8xf32
 // AVX2-LABEL: func @transpose102_8x1x8
 func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> {
+  //      ELTWISE: vector.extract {{.*}}[0, 0] : vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[1, 0] : vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 1] : vector<8xf32> into vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[2, 0] : vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 2] : vector<8xf32> into vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[3, 0] : vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 3] : vector<8xf32> into vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[4, 0] : vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 4] : vector<8xf32> into vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[5, 0] : vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 5] : vector<8xf32> into vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[6, 0] : vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 6] : vector<8xf32> into vector<1x8x8xf32>
+  // ELTWISE-NEXT: vector.extract {{.*}}[7, 0] : vector<8x1x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 7] : vector<8xf32> into vector<1x8x8xf32>
   %0 = vector.transpose %arg0, [1, 0, 2] : vector<8x1x8xf32> to vector<1x8x8xf32>
   return %0 : vector<1x8x8xf32>
 }
@@ -597,6 +631,20 @@ func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> {
 
 // -----
 
+// ELTWISE-LABEL:   func @transpose1023_1x1x8x8xf32(
+// AVX2-LABEL: func @transpose1023_1x1x8x8
+func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8x8xf32> {
+  // Note the single 2-D extract/insert pair since 2 and 3 are not transposed!
+  //      ELTWISE: vector.extract {{.*}}[0, 0] : vector<1x1x8x8xf32>
+  // ELTWISE-NEXT: vector.insert {{.*}} [0, 0] : vector<8x8xf32> into vector<1x1x8x8xf32>
+  %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<1x1x8x8xf32> to vector<1x1x8x8xf32>
+  return %0 : vector<1x1x8x8xf32>
+}
+
+// AVX2-NOT: vector.shuffle
+
+// -----
+
 // AVX2-LABEL: func @transpose120_1x8x8
 func @transpose120_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x8x1xf32> {
 
