[Mlir-commits] [mlir] 8d46bfa - [mlir] [VectorOps] A "reference" lowering of vector.transpose to LLVM IR
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 23 19:01:51 PDT 2020
Author: aartbik
Date: 2020-03-23T19:01:38-07:00
New Revision: 8d46bfa8084700f2b5c0cb2b668024290d9ed729
URL: https://github.com/llvm/llvm-project/commit/8d46bfa8084700f2b5c0cb2b668024290d9ed729
DIFF: https://github.com/llvm/llvm-project/commit/8d46bfa8084700f2b5c0cb2b668024290d9ed729.diff
LOG: [mlir] [VectorOps] A "reference" lowering of vector.transpose to LLVM IR
Summary: Makes the vector.transpose op runnable on CPU.
Reviewers: nicolasvasilache, andydavis1, rriddle
Reviewed By: nicolasvasilache
Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D76644
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 50fa0150ba53..2a8835102d59 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -53,7 +53,9 @@ void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
/// Collect a set of transformation patterns that are related to contracting
/// or expanding vector operations:
/// ContractionOpLowering,
-/// ShapeCastOp2DDownCastRewritePattern, ShapeCastOp2DUpCastRewritePattern
+/// ShapeCastOp2DDownCastRewritePattern,
+/// ShapeCastOp2DUpCastRewritePattern,
+/// TransposeOpLowering,
/// OuterproductOpLowering
/// These transformation express higher level vector ops in terms of more
/// elementary extraction, insertion, reduction, product, and broadcast ops.
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 51d7962cdfc4..2e895e63ba27 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -88,7 +88,7 @@ def Vector_ContractionOp :
iterator in the iterator type list, to each dimension of an N-D vector.
Examples:
-
+ ```mlir
// Simple dot product (K = 0).
#contraction_accesses = [
affine_map<(i) -> (i)>,
@@ -139,6 +139,7 @@ def Vector_ContractionOp :
%5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
+ ```
}];
let builders = [OpBuilder<
"Builder *builder, OperationState &result, Value lhs, Value rhs, "
@@ -203,7 +204,7 @@ def Vector_ReductionOp :
http://llvm.org/docs/LangRef.html#experimental-vector-reduction-intrinsics
Examples:
- ```
+ ```mlir
%1 = vector.reduction "add", %0 : vector<16xf32> into f32
%3 = vector.reduction "xor", %2 : vector<4xi32> into i32
@@ -247,7 +248,7 @@ def Vector_BroadcastOp :
shaped vector with the same element type is always legal.
Examples:
- ```
+ ```mlir
%0 = constant 0.0 : f32
%1 = vector.broadcast %0 : f32 to vector<16xf32>
%2 = vector.broadcast %1 : vector<16xf32> to vector<4x16xf32>
@@ -290,7 +291,7 @@ def Vector_ShuffleOp :
above, all mask values are in the range [0,s_1+t_1)
Examples:
- ```
+ ```mlir
%0 = vector.shuffle %a, %b[0, 3]
: vector<2xf32>, vector<2xf32> ; yields vector<2xf32>
%1 = vector.shuffle %c, %b[0, 1, 2]
@@ -332,7 +333,7 @@ def Vector_ExtractElementOp :
https://llvm.org/docs/LangRef.html#extractelement-instruction
Example:
- ```
+ ```mlir
%c = constant 15 : i32
%1 = vector.extractelement %0[%c : i32]: vector<16xf32>
```
@@ -360,7 +361,7 @@ def Vector_ExtractOp :
the proper position. Degenerates to an element type in the 0-D case.
Examples:
- ```
+ ```mlir
%1 = vector.extract %0[3]: vector<4x8x16xf32>
%2 = vector.extract %0[3, 3, 3]: vector<4x8x16xf32>
```
@@ -396,7 +397,7 @@ def Vector_ExtractSlicesOp :
Currently, only unit strides are supported.
Examples:
- ```
+ ```mlir
%0 = vector.transfer_read ...: vector<4x2xf32>
%1 = vector.extract_slices %0, [2, 2], [1, 1]
@@ -448,8 +449,7 @@ def Vector_FMAOp :
to the `llvm.fma.*` intrinsic.
Example:
-
- ```
+ ```mlir
%3 = vector.fma %0, %1, %2: vector<8x16xf32>
```
}];
@@ -483,7 +483,7 @@ def Vector_InsertElementOp :
https://llvm.org/docs/LangRef.html#insertelement-instruction
Example:
- ```
+ ```mlir
%c = constant 15 : i32
%f = constant 0.0f : f32
%1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32>
@@ -516,7 +516,7 @@ def Vector_InsertOp :
position. Degenerates to a scalar source type when n = 0.
Examples:
- ```
+ ```mlir
%2 = vector.insert %0, %1[3]:
vector<8x16xf32> into vector<4x8x16xf32>
%5 = vector.insert %3, %4[3, 3, 3]:
@@ -559,7 +559,7 @@ def Vector_InsertSlicesOp :
Currently, only unit strides are supported.
Examples:
- ```
+ ```mlir
%0 = vector.extract_slices %0, [2, 2], [1, 1]
: vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
@@ -617,7 +617,7 @@ def Vector_InsertStridedSliceOp :
the proper location as specified by the offsets.
Examples:
- ```
+ ```mlir
%2 = vector.insert_strided_slice %0, %1
{offsets = [0, 0, 2], strides = [1, 1]}:
vector<2x4xf32> into vector<16x4x8xf32>
@@ -659,8 +659,7 @@ def Vector_OuterProductOp :
lower to actual `fma` instructions on x86.
Examples:
-
- ```
+ ```mlir
%2 = vector.outerproduct %0, %1: vector<4xf32>, vector<8xf32>
return %2: vector<4x8xf32>
@@ -709,8 +708,8 @@ def Vector_ReshapeOp :
In the examples below, valid data elements are represented by an alphabetic
character, and undefined data elements are represented by '-'.
- Example
-
+ Example:
+ ```mlir
vector<1x8xf32> with valid data shape [6], fixed vector sizes [8]
input: [a, b, c, d, e, f]
@@ -719,8 +718,9 @@ def Vector_ReshapeOp :
vector layout: [a, b, c, d, e, f, -, -]
- Example
-
+ ```
+ Example:
+ ```mlir
vector<2x8xf32> with valid data shape [10], fixed vector sizes [8]
input: [a, b, c, d, e, f, g, h, i, j]
@@ -729,9 +729,9 @@ def Vector_ReshapeOp :
vector layout: [[a, b, c, d, e, f, g, h],
[i, j, -, -, -, -, -, -]]
-
- Example
-
+ ```
+ Example:
+ ```mlir
vector<2x2x2x3xf32> with valid data shape [3, 5], fixed vector sizes
[2, 3]
@@ -750,9 +750,9 @@ def Vector_ReshapeOp :
[-, -, -]]
[[n, o, -],
[-, -, -]]]]
-
- Example
-
+ ```
+ Example:
+ ```mlir
%1 = vector.reshape %0, [%c3, %c6], [%c2, %c9], [4]
: vector<3x2x4xf32> to vector<2x3x4xf32>
@@ -776,6 +776,7 @@ def Vector_ReshapeOp :
[[j, k, l, m],
[n, o, p, q],
[r, -, -, -]]]
+ ```
}];
let extraClassDeclaration = [{
@@ -828,7 +829,7 @@ def Vector_StridedSliceOp :
`offsets` and ending at `offsets + sizes`.
Examples:
- ```
+ ```mlir
%1 = vector.strided_slice %0
{offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}:
vector<4x8x16xf32> to vector<2x4x16xf32>
@@ -947,13 +948,12 @@ def Vector_TransferReadOp :
implemented using a warp-shuffle if loop `j` were mapped to `threadIdx.x`.
Syntax
- ```
+ ```mlir
operation ::= ssa-id `=` `vector.transfer_read` ssa-use-list
`{` attribute-entry `} :` memref-type `,` vector-type
```
Examples:
-
```mlir
// Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into vector<32x256xf32>
// and pad with %f0 to handle the boundary case:
@@ -1028,7 +1028,7 @@ def Vector_TransferWriteOp :
Syntax:
- ```
+ ```mlir
operation ::= `vector.transfer_write` ssa-use-list `{` attribute-entry `} :
` vector-type ', ' memref-type '
```
@@ -1139,7 +1139,7 @@ def Vector_TypeCastOp :
Syntax:
- ```
+ ```mlir
operation ::= `vector.type_cast` ssa-use : memref-type to memref-type
```
@@ -1183,8 +1183,10 @@ def Vector_ConstantMaskOp :
define a hyper-rectangular region within which elements values are set to 1
(otherwise element values are set to 0).
- Example: create a constant vector mask of size 4x3xi1 with elements in range
- 0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0).
+ Example:
+ ```
+ create a constant vector mask of size 4x3xi1 with elements in range
+ 0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0).
%1 = vector.constant_mask [3, 2] : vector<4x3xi1>
@@ -1196,6 +1198,7 @@ def Vector_ConstantMaskOp :
rows 1 | 1 1 0
2 | 1 1 0
3 | 0 0 0
+ ```
}];
let extraClassDeclaration = [{
@@ -1217,8 +1220,10 @@ def Vector_CreateMaskOp :
hyper-rectangular region within which elements values are set to 1
(otherwise element values are set to 0).
- Example: create a vector mask of size 4x3xi1 where elements in range
- 0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0).
+ Example:
+ ```
+ create a vector mask of size 4x3xi1 where elements in range
+ 0 <= row <= 2 and 0 <= col <= 1 are set to 1 (others to 0).
%1 = vector.create_mask %c3, %c2 : vector<4x3xi1>
@@ -1230,6 +1235,7 @@ def Vector_CreateMaskOp :
rows 1 | 1 1 0
2 | 1 1 0
3 | 0 0 0
+ ```
}];
let hasCanonicalizer = 1;
@@ -1248,9 +1254,8 @@ def Vector_TupleOp :
transformation and should be removed before lowering to lower-level
dialects.
-
Examples:
- ```
+ ```mlir
%0 = vector.transfer_read ... : vector<2x2xf32>
%1 = vector.transfer_read ... : vector<2x1xf32>
%2 = vector.transfer_read ... : vector<2x2xf32>
@@ -1280,20 +1285,21 @@ def Vector_TransposeOp :
Takes a n-D vector and returns the transposed n-D vector defined by
the permutation of ranks in the n-sized integer array attribute.
In the operation
-
- %1 = vector.tranpose %0, [i_1, .., i_n]
- : vector<d_1 x .. x d_n x f32>
- to vector<d_trans[0] x .. x d_trans[n-1] x f32>
-
+ ```mlir
+ %1 = vector.transpose %0, [i_1, .., i_n]
+ : vector<d_1 x .. x d_n x f32>
+ to vector<d_trans[0] x .. x d_trans[n-1] x f32>
+ ```
the transp array [i_1, .., i_n] must be a permutation of [0, .., n-1].
Example:
-
+ ```mlir
+ %1 = vector.transpose %0, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
[ [a, b, c], [ [a, d],
[d, e, f] ] -> [b, e],
[c, f] ]
+ ```
}];
let extraClassDeclaration = [{
VectorType getVectorType() {
@@ -1321,7 +1327,7 @@ def Vector_TupleGetOp :
dialects.
Examples:
- ```
+ ```mlir
%4 = vector.tuple %0, %1, %2, %3
: vector<2x2xf32>, vector<2x1xf32>, vector<2x2xf32>, vector<2x1xf32>>
@@ -1351,7 +1357,7 @@ def Vector_PrintOp :
format (for testing and debugging). No return value.
Examples:
- ```
+ ```mlir
%0 = constant 0.0 : f32
%1 = vector.broadcast %0 : f32 to vector<4xf32>
vector.print %1 : vector<4xf32>
@@ -1414,7 +1420,7 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect,
Example:
- ```
+ ```mlir
%C = vector.matrix_multiply %A, %B
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
(vector<64xf64>, vector<48xf64>) -> vector<12xf64>
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 6e54e5b05fb6..ef3484d31a3c 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -864,6 +864,67 @@ class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
}
};
+/// Progressive lowering of TransposeOp.
+/// One:
+/// %x = vector.transpose %y, [1, 0]
+/// is replaced by:
+/// %z = constant dense<0.000000e+00>
+/// %0 = vector.extract %y[0, 0]
+/// %1 = vector.insert %0, %z [0, 0]
+/// ..
+/// %x = vector.insert .., .. [.., ..]
+class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
+public:
+ using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransposeOp op,
+ PatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+
+ VectorType resType = op.getResultType();
+ Type eltType = resType.getElementType();
+
+ // Set up convenience transposition table.
+ SmallVector<int64_t, 4> transp;
+ for (auto attr : op.transp())
+ transp.push_back(attr.cast<IntegerAttr>().getInt());
+
+ // Generate fully unrolled extract/insert ops.
+ Value zero = rewriter.create<ConstantOp>(loc, eltType,
+ rewriter.getZeroAttr(eltType));
+ Value result = rewriter.create<SplatOp>(loc, resType, zero);
+ SmallVector<int64_t, 4> lhs(transp.size(), 0);
+ SmallVector<int64_t, 4> rhs(transp.size(), 0);
+ rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
+ op.vector(), result, rewriter));
+ return success();
+ }
+
+private:
+ // Builds the indices arrays for the lhs and rhs. Generates the extract/insert
+ // operation when all ranks are exhausted.
+ Value expandIndices(Location loc, VectorType resType, int64_t pos,
+ SmallVector<int64_t, 4> &transp,
+ SmallVector<int64_t, 4> &lhs,
+ SmallVector<int64_t, 4> &rhs, Value input, Value result,
+ PatternRewriter &rewriter) const {
+ if (pos >= resType.getRank()) {
+ auto ridx = rewriter.getI64ArrayAttr(rhs);
+ auto lidx = rewriter.getI64ArrayAttr(lhs);
+ Type eltType = resType.getElementType();
+ Value e = rewriter.create<vector::ExtractOp>(loc, eltType, input, ridx);
+ return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx);
+ }
+ for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
+ lhs[pos] = d;
+ rhs[transp[pos]] = d;
+ result = expandIndices(loc, resType, pos + 1, transp, lhs, rhs, input,
+ result, rewriter);
+ }
+ return result;
+ }
+};
+
/// Progressive lowering of OuterProductOp.
/// One:
/// %x = vector.outerproduct %lhs, %rhs, %acc
@@ -1353,7 +1414,7 @@ void mlir::vector::populateVectorContractLoweringPatterns(
OwningRewritePatternList &patterns, MLIRContext *context,
VectorTransformsOptions parameters) {
patterns.insert<ShapeCastOp2DDownCastRewritePattern,
- ShapeCastOp2DUpCastRewritePattern, OuterProductOpLowering>(
- context);
+ ShapeCastOp2DUpCastRewritePattern, TransposeOpLowering,
+ OuterProductOpLowering>(context);
patterns.insert<ContractionOpLowering>(parameters, context);
}
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index bed90d6341d9..051c42d32ed5 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -296,6 +296,28 @@ func @outerproduct_acc(%arg0: vector<2xf32>,
return %0: vector<2x3xf32>
}
+// CHECK-LABEL: func @transpose23
+// CHECK-SAME: %[[A:.*]]: vector<2x3xf32>
+// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32>
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
+// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32>
+// CHECK: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32>
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32>
+// CHECK: return %[[T11]] : vector<3x2xf32>
+
+func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
+ %0 = vector.transpose %arg0, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
+ return %0 : vector<3x2xf32>
+}
+
// Shape up and downcasts for 2-D vectors, for supporting conversion to
// llvm.matrix operations
// CHECK-LABEL: func @shape_casts
More information about the Mlir-commits
mailing list