[Mlir-commits] [mlir] 8d46bfa - [mlir] [VectorOps] A "reference" lowering of vector.transpose to LLVM IR

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 23 19:01:51 PDT 2020


Author: aartbik
Date: 2020-03-23T19:01:38-07:00
New Revision: 8d46bfa8084700f2b5c0cb2b668024290d9ed729

URL: https://github.com/llvm/llvm-project/commit/8d46bfa8084700f2b5c0cb2b668024290d9ed729
DIFF: https://github.com/llvm/llvm-project/commit/8d46bfa8084700f2b5c0cb2b668024290d9ed729.diff

LOG: [mlir] [VectorOps] A "reference" lowering of vector.transpose to LLVM IR

Summary: Makes vector.transpose runnable on CPU.

Reviewers: nicolasvasilache, andydavis1, rriddle

Reviewed By: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D76644

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 50fa0150ba53..2a8835102d59 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -53,7 +53,9 @@ void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
 /// Collect a set of transformation patterns that are related to contracting
 /// or expanding vector operations:
 ///   ContractionOpLowering,
-///   ShapeCastOp2DDownCastRewritePattern, ShapeCastOp2DUpCastRewritePattern
+///   ShapeCastOp2DDownCastRewritePattern,
+///   ShapeCastOp2DUpCastRewritePattern,
+///   TransposeOpLowering,
 ///   OuterproductOpLowering
 /// These transformation express higher level vector ops in terms of more
 /// elementary extraction, insertion, reduction, product, and broadcast ops.

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 51d7962cdfc4..2e895e63ba27 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -88,7 +88,7 @@ def Vector_ContractionOp :
     iterator in the iterator type list, to each dimension of an N-D vector.
 
     Examples:
-
+    ```mlir
       // Simple dot product (K = 0).
       #contraction_accesses = [
        affine_map<(i) -> (i)>,
@@ -139,6 +139,7 @@ def Vector_ContractionOp :
 
       %5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
          : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
+    ```
   }];
   let builders = [OpBuilder<
     "Builder *builder, OperationState &result, Value lhs, Value rhs, "
@@ -203,7 +204,7 @@ def Vector_ReductionOp :
     http://llvm.org/docs/LangRef.html#experimental-vector-reduction-intrinsics
 
     Examples:
-    ```
+    ```mlir
       %1 = vector.reduction "add", %0 : vector<16xf32> into f32
 
       %3 = vector.reduction "xor", %2 : vector<4xi32> into i32
@@ -247,7 +248,7 @@ def Vector_BroadcastOp :
     shaped vector with the same element type is always legal.
 
     Examples:
-    ```
+    ```mlir
       %0 = constant 0.0 : f32
       %1 = vector.broadcast %0 : f32 to vector<16xf32>
       %2 = vector.broadcast %1 : vector<16xf32> to vector<4x16xf32>
@@ -290,7 +291,7 @@ def Vector_ShuffleOp :
       above, all mask values are in the range [0,s_1+t_1)
 
     Examples:
-    ```
+    ```mlir
     %0 = vector.shuffle %a, %b[0, 3]
                : vector<2xf32>, vector<2xf32>       ; yields vector<2xf32>
     %1 = vector.shuffle %c, %b[0, 1, 2]
@@ -332,7 +333,7 @@ def Vector_ExtractElementOp :
     https://llvm.org/docs/LangRef.html#extractelement-instruction
 
     Example:
-    ```
+    ```mlir
       %c = constant 15 : i32
       %1 = vector.extractelement %0[%c : i32]: vector<16xf32>
     ```
@@ -360,7 +361,7 @@ def Vector_ExtractOp :
     the proper position. Degenerates to an element type in the 0-D case.
 
     Examples:
-    ```
+    ```mlir
       %1 = vector.extract %0[3]: vector<4x8x16xf32>
       %2 = vector.extract %0[3, 3, 3]: vector<4x8x16xf32>
     ```
@@ -396,7 +397,7 @@ def Vector_ExtractSlicesOp :
     Currently, only unit strides are supported.
 
     Examples:
-    ```
+    ```mlir
       %0 = vector.transfer_read ...: vector<4x2xf32>
 
       %1 = vector.extract_slices %0, [2, 2], [1, 1]
@@ -448,8 +449,7 @@ def Vector_FMAOp :
     to the `llvm.fma.*` intrinsic.
 
     Example:
-
-    ```
+    ```mlir
       %3 = vector.fma %0, %1, %2: vector<8x16xf32>
     ```
   }];
@@ -483,7 +483,7 @@ def Vector_InsertElementOp :
     https://llvm.org/docs/LangRef.html#insertelement-instruction
 
     Example:
-    ```
+    ```mlir
       %c = constant 15 : i32
       %f = constant 0.0f : f32
       %1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32>
@@ -516,7 +516,7 @@ def Vector_InsertOp :
     position. Degenerates to a scalar source type when n = 0.
 
     Examples:
-    ```
+    ```mlir
       %2 = vector.insert %0, %1[3]:
         vector<8x16xf32> into vector<4x8x16xf32>
       %5 = vector.insert %3, %4[3, 3, 3]:
@@ -559,7 +559,7 @@ def Vector_InsertSlicesOp :
     Currently, only unit strides are supported.
 
     Examples:
-    ```
+    ```mlir
       %0 = vector.extract_slices %0, [2, 2], [1, 1]
         : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
 
@@ -617,7 +617,7 @@ def Vector_InsertStridedSliceOp :
     the proper location as specified by the offsets.
 
     Examples:
-    ```
+    ```mlir
       %2 = vector.insert_strided_slice %0, %1
           {offsets = [0, 0, 2], strides = [1, 1]}:
         vector<2x4xf32> into vector<16x4x8xf32>
@@ -659,8 +659,7 @@ def Vector_OuterProductOp :
     lower to actual `fma` instructions on x86.
 
     Examples:
-
-    ```
+    ```mlir
       %2 = vector.outerproduct %0, %1: vector<4xf32>, vector<8xf32>
       return %2: vector<4x8xf32>
 
@@ -709,8 +708,8 @@ def Vector_ReshapeOp :
     In the examples below, valid data elements are represented by an alphabetic
     character, and undefined data elements are represented by '-'.
 
-    Example
-
+    Example:
+    ```mlir
       vector<1x8xf32> with valid data shape [6], fixed vector sizes [8]
 
                 input: [a, b, c, d, e, f]
@@ -719,8 +718,9 @@ def Vector_ReshapeOp :
 
         vector layout: [a, b, c, d, e, f, -, -]
 
-    Example
-
+    ```
+    Example:
+    ```mlir
       vector<2x8xf32> with valid data shape [10], fixed vector sizes [8]
 
                 input: [a, b, c, d, e, f, g, h, i, j]
@@ -729,9 +729,9 @@ def Vector_ReshapeOp :
 
         vector layout: [[a, b, c, d, e, f, g, h],
                         [i, j, -, -, -, -, -, -]]
-
-    Example
-
+    ```
+    Example:
+    ```mlir
       vector<2x2x2x3xf32> with valid data shape [3, 5], fixed vector sizes
       [2, 3]
 
@@ -750,9 +750,9 @@ def Vector_ReshapeOp :
                           [-, -, -]]
                          [[n, o, -],
                           [-, -, -]]]]
-
-    Example
-
+    ```
+    Example:
+    ```mlir
       %1 = vector.reshape %0, [%c3, %c6], [%c2, %c9], [4]
         : vector<3x2x4xf32> to vector<2x3x4xf32>
 
@@ -776,6 +776,7 @@ def Vector_ReshapeOp :
                        [[j, k, l, m],
                         [n, o, p, q],
                         [r, -, -, -]]]
+    ```
   }];
 
   let extraClassDeclaration = [{
@@ -828,7 +829,7 @@ def Vector_StridedSliceOp :
     `offsets` and ending at `offsets + sizes`.
 
     Examples:
-    ```
+    ```mlir
       %1 = vector.strided_slice %0
           {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}:
         vector<4x8x16xf32> to vector<2x4x16xf32>
@@ -947,13 +948,12 @@ def Vector_TransferReadOp :
     implemented using a warp-shuffle if loop `j` were mapped to `threadIdx.x`.
 
     Syntax
-    ```
+    ```mlir
     operation ::= ssa-id `=` `vector.transfer_read` ssa-use-list
       `{` attribute-entry `} :` memref-type `,` vector-type
     ```
 
     Examples:
-
     ```mlir
     // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into vector<32x256xf32>
     // and pad with %f0 to handle the boundary case:
@@ -1028,7 +1028,7 @@ def Vector_TransferWriteOp :
 
     Syntax:
 
-    ```
+    ```mlir
     operation ::= `vector.transfer_write` ssa-use-list `{` attribute-entry `} :
       ` vector-type ', ' memref-type '
     ```
@@ -1139,7 +1139,7 @@ def Vector_TypeCastOp :
 
     Syntax:
 
-    ```
+    ```mlir
     operation ::= `vector.type_cast` ssa-use : memref-type to memref-type
     ```
 
@@ -1183,8 +1183,10 @@ def Vector_ConstantMaskOp :
     define a hyper-rectangular region within which elements values are set to 1
     (otherwise element values are set to 0).
 
-    Example: create a constant vector mask of size 4x3xi1 with elements in range
-             0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0).
+    Example:
+    ```
+      create a constant vector mask of size 4x3xi1 with elements in range
+      0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0).
 
       %1 = vector.constant_mask [3, 2] : vector<4x3xi1>
 
@@ -1196,6 +1198,7 @@ def Vector_ConstantMaskOp :
         rows  1 | 1    1    0
               2 | 1    1    0
               3 | 0    0    0
+    ```
   }];
 
   let extraClassDeclaration = [{
@@ -1217,8 +1220,10 @@ def Vector_CreateMaskOp :
     hyper-rectangular region within which elements values are set to 1
     (otherwise element values are set to 0).
 
-    Example: create a vector mask of size 4x3xi1 where elements in range
-             0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0).
+    Example:
+    ```
+      create a vector mask of size 4x3xi1 where elements in range
+      0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0).
 
       %1 = vector.create_mask %c3, %c2 : vector<4x3xi1>
 
@@ -1230,6 +1235,7 @@ def Vector_CreateMaskOp :
         rows  1 | 1    1    0
               2 | 1    1    0
               3 | 0    0    0
+    ```
   }];
 
   let hasCanonicalizer = 1;
@@ -1248,9 +1254,8 @@ def Vector_TupleOp :
     transformation and should be removed before lowering to lower-level
     dialects.
 
-
     Examples:
-    ```
+    ```mlir
       %0 = vector.transfer_read ... : vector<2x2xf32>
       %1 = vector.transfer_read ... : vector<2x1xf32>
       %2 = vector.transfer_read ... : vector<2x2xf32>
@@ -1280,20 +1285,21 @@ def Vector_TransposeOp :
     Takes a n-D vector and returns the transposed n-D vector defined by
     the permutation of ranks in the n-sized integer array attribute.
     In the operation
-
-    %1 = vector.tranpose %0, [i_1, .., i_n]
-      : vector<d_1 x .. x d_n x f32>
-      to vector<d_trans[0] x .. x d_trans[n-1] x f32>
-
+    ```mlir
+      %1 = vector.transpose %0, [i_1, .., i_n]
+        : vector<d_1 x .. x d_n x f32>
+        to vector<d_trans[0] x .. x d_trans[n-1] x f32>
+    ```
     the transp array [i_1, .., i_n] must be a permutation of [0, .., n-1].
 
     Example:
-
+    ```mlir
     %1 = vector.tranpose %0, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
 
      [ [a, b, c],       [ [a, d],
        [d, e, f] ]  ->    [b, e],
                           [c, f] ]
+    ```
   }];
   let extraClassDeclaration = [{
     VectorType getVectorType() {
@@ -1321,7 +1327,7 @@ def Vector_TupleGetOp :
     dialects.
 
     Examples:
-    ```
+    ```mlir
       %4 = vector.tuple %0, %1, %2, %3
         : vector<2x2xf32>, vector<2x1xf32>, vector<2x2xf32>, vector<2x1xf32>>
 
@@ -1351,7 +1357,7 @@ def Vector_PrintOp :
     format (for testing and debugging). No return value.
 
     Examples:
-    ```
+    ```mlir
       %0 = constant 0.0 : f32
       %1 = vector.broadcast %0 : f32 to vector<4xf32>
       vector.print %1 : vector<4xf32>
@@ -1414,7 +1420,7 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect,
 
     Example:
 
-    ```
+    ```mlir
       %C = vector.matrix_multiply %A, %B
         { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
         (vector<64xf64>, vector<48xf64>) -> vector<12xf64>

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 6e54e5b05fb6..ef3484d31a3c 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -864,6 +864,67 @@ class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
   }
 };
 
+/// Progressive lowering of TransposeOp.
+/// One:
+///   %x = vector.transpose %y, [1, 0]
+/// is replaced by a fully unrolled sequence of element copies:
+///   %z = constant dense<0.000000e+00>
+///   %0 = vector.extract %y[0, 0]
+///   %1 = vector.insert %0, %z [0, 0]
+///   ..
+///   %x = vector.insert .., .. [.., ..]
+class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    VectorType resType = op.getResultType();
+    Type eltType = resType.getElementType();
+
+    // Transposition table: result dim i comes from source dim transp[i].
+    SmallVector<int64_t, 4> transp;
+    for (auto attr : op.transp())
+      transp.push_back(attr.cast<IntegerAttr>().getInt());
+
+    // Start from an all-zero result and copy every element into place.
+    Value zero = rewriter.create<ConstantOp>(loc, eltType,
+                                             rewriter.getZeroAttr(eltType));
+    Value result = rewriter.create<SplatOp>(loc, resType, zero);
+    SmallVector<int64_t, 4> lhs(transp.size(), 0);
+    SmallVector<int64_t, 4> rhs(transp.size(), 0);
+    rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
+                                         op.vector(), result, rewriter));
+    return success();
+  }
+
+private:
+  // Recursively enumerates every result position. `lhs` is the index into
+  // the result and `rhs` the matching index into `input` (rhs[transp[i]] ==
+  // lhs[i]); one extract/insert pair is emitted once all ranks are exhausted.
+  Value expandIndices(Location loc, VectorType resType, int64_t pos,
+                      ArrayRef<int64_t> transp, MutableArrayRef<int64_t> lhs,
+                      MutableArrayRef<int64_t> rhs, Value input, Value result,
+                      PatternRewriter &rewriter) const {
+    if (pos >= resType.getRank()) {
+      auto ridx = rewriter.getI64ArrayAttr(rhs);
+      auto lidx = rewriter.getI64ArrayAttr(lhs);
+      Type eltType = resType.getElementType();
+      Value elt = rewriter.create<vector::ExtractOp>(loc, eltType, input, ridx);
+      return rewriter.create<vector::InsertOp>(loc, resType, elt, result, lidx);
+    }
+    for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
+      lhs[pos] = d;
+      rhs[transp[pos]] = d;
+      result = expandIndices(loc, resType, pos + 1, transp, lhs, rhs, input,
+                             result, rewriter);
+    }
+    return result;
+  }
+};
+
 /// Progressive lowering of OuterProductOp.
 /// One:
 ///   %x = vector.outerproduct %lhs, %rhs, %acc
@@ -1353,7 +1414,7 @@ void mlir::vector::populateVectorContractLoweringPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context,
     VectorTransformsOptions parameters) {
   patterns.insert<ShapeCastOp2DDownCastRewritePattern,
-                  ShapeCastOp2DUpCastRewritePattern, OuterProductOpLowering>(
-      context);
+                  ShapeCastOp2DUpCastRewritePattern, TransposeOpLowering,
+                  OuterProductOpLowering>(context);
   patterns.insert<ContractionOpLowering>(parameters, context);
 }

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index bed90d6341d9..051c42d32ed5 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -296,6 +296,28 @@ func @outerproduct_acc(%arg0: vector<2xf32>,
   return %0: vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @transpose23
+// CHECK-SAME: %[[A:.*]]: vector<2x3xf32>
+// CHECK:      %[[Z:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32>
+// CHECK:      %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32>
+// CHECK:      %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
+// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32>
+// CHECK:      %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
+// CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32>
+// CHECK:      %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
+// CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32>
+// CHECK:      %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
+// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32>
+// CHECK:      %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32>
+// CHECK:      %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32>
+// CHECK:      return %[[T11]] : vector<3x2xf32>
+// Unrolled transpose: result element [i, j] is extracted from source [j, i].
+func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
+  %0 = vector.transpose %arg0, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
+  return %0 : vector<3x2xf32>
+}
+
 // Shape up and downcasts for 2-D vectors, for supporting conversion to
 // llvm.matrix operations
 // CHECK-LABEL: func @shape_casts


        


More information about the Mlir-commits mailing list