[Mlir-commits] [mlir] 9032240 - [mlir][Vector] Allow lowering of vector.shape_cast 2D <-> 1D
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Mar 9 10:18:27 PDT 2020
Author: Nicolas Vasilache
Date: 2020-03-09T13:14:39-04:00
New Revision: 90322403c203ac180b31930f148555d13e03b121
URL: https://github.com/llvm/llvm-project/commit/90322403c203ac180b31930f148555d13e03b121
DIFF: https://github.com/llvm/llvm-project/commit/90322403c203ac180b31930f148555d13e03b121.diff
LOG: [mlir][Vector] Allow lowering of vector.shape_cast 2D <-> 1D
Summary:
This will support the progressive lowering of:
```
vector.contract ->
downcast + vector.matrix_multiply + upcast ->
llvm.intr.matrix
```
Differential Revision: https://reviews.llvm.org/D75776
Added:
Modified:
mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index aee269555bd3..67a880d2e5d6 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -1091,7 +1091,12 @@ def Vector_ShapeCastOp :
described above is applied to each source/result tuple element pair.
It is currently assumed that this operation does not require moving data,
- and that it will be canonicalized away before lowering vector operations.
+ and that it will be folded away before lowering vector operations.
+
+ There is an exception to the folding expectation when targeting
+ llvm.intr.matrix operations. We need a type conversion back and forth from a
+ 2-D MLIR vector to a 1-D flattened LLVM vector. vector.shape_cast lowering to
+ LLVM is supported in that particular case, for now.
Examples:
@@ -1108,6 +1113,14 @@ def Vector_ShapeCastOp :
tuple<vector<12x2xf32>, vector<9x2xf32>>
```
}];
+ let extraClassDeclaration = [{
+ /// Returns the source type as a VectorType. This uses cast<>, so the source
+ /// must not be a tuple of vectors (a form this op also accepts).
+ VectorType getSourceVectorType() {
+ return source().getType().cast<VectorType>();
+ }
+ /// Returns the result type as a VectorType. This uses cast<>, so the result
+ /// must not be a tuple of vectors (a form this op also accepts).
+ VectorType getResultVectorType() {
+ return getResult().getType().cast<VectorType>();
+ }
+ }];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
}
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 8764d487dfb9..00089ebefd12 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -1171,6 +1171,75 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
}
};
+/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
+/// vectors progressively on the way to target llvm.matrix intrinsics.
+/// This iterates over the most major dimension of the 2-D vector and performs
+/// rewrites into:
+///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
+///
+/// Only plain vector-typed casts are matched. vector.shape_cast also accepts
+/// tuple-of-vector operands; for those the pattern fails to match instead of
+/// asserting in a VectorType cast.
+class ShapeCastOp2DDownCastRewritePattern
+    : public OpRewritePattern<vector::ShapeCastOp> {
+public:
+  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(vector::ShapeCastOp op,
+                                     PatternRewriter &rewriter) const override {
+    // Use dyn_cast rather than getSourceVectorType()/getResultVectorType():
+    // those cast<> unconditionally and would assert on tuple-typed casts.
+    auto sourceVectorType = op.source().getType().dyn_cast<VectorType>();
+    auto resultVectorType = op.getResult().getType().dyn_cast<VectorType>();
+    if (!sourceVectorType || !resultVectorType)
+      return matchFailure();
+    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
+      return matchFailure();
+
+    auto loc = op.getLoc();
+    auto elemType = sourceVectorType.getElementType();
+    // Start from an all-zero 1-D vector and fill it one row at a time.
+    Value zero = rewriter.create<ConstantOp>(loc, elemType,
+                                             rewriter.getZeroAttr(elemType));
+    Value desc = rewriter.create<SplatOp>(loc, resultVectorType, zero);
+    // Kept as int64_t to match the ArrayRef<int64_t> offsets operand.
+    int64_t mostMinorVectorSize = sourceVectorType.getShape()[1];
+    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
+      // Row i of the 2-D source lands at offset i * rowSize in the 1-D result.
+      Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
+      desc = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, vec, desc,
+          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
+    }
+    rewriter.replaceOp(op, desc);
+    return matchSuccess();
+  }
+};
+
+/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 1-D vectors
+/// back into 2-D vectors progressively, on the way back from llvm.matrix
+/// intrinsics.
+/// This iterates over the most major dimension of the 2-D vector and performs
+/// rewrites into:
+///   vector.strided_slice from 1-D + vector.insert into 2-D
+///
+/// Only plain vector-typed casts are matched. vector.shape_cast also accepts
+/// tuple-of-vector operands; for those the pattern fails to match instead of
+/// asserting in a VectorType cast.
+class ShapeCastOp2DUpCastRewritePattern
+    : public OpRewritePattern<vector::ShapeCastOp> {
+public:
+  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(vector::ShapeCastOp op,
+                                     PatternRewriter &rewriter) const override {
+    // Use dyn_cast rather than getSourceVectorType()/getResultVectorType():
+    // those cast<> unconditionally and would assert on tuple-typed casts.
+    auto sourceVectorType = op.source().getType().dyn_cast<VectorType>();
+    auto resultVectorType = op.getResult().getType().dyn_cast<VectorType>();
+    if (!sourceVectorType || !resultVectorType)
+      return matchFailure();
+    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
+      return matchFailure();
+
+    auto loc = op.getLoc();
+    auto elemType = sourceVectorType.getElementType();
+    // Start from an all-zero 2-D vector and fill it one row at a time.
+    Value zero = rewriter.create<ConstantOp>(loc, elemType,
+                                             rewriter.getZeroAttr(elemType));
+    Value desc = rewriter.create<SplatOp>(loc, resultVectorType, zero);
+    // Kept as int64_t to match the ArrayRef<int64_t> offsets/sizes operands.
+    int64_t mostMinorVectorSize = resultVectorType.getShape()[1];
+    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
+      // Row i of the 2-D result is the contiguous slice
+      // [i * rowSize, (i + 1) * rowSize) of the 1-D source.
+      Value vec = rewriter.create<vector::StridedSliceOp>(
+          loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
+          /*sizes=*/mostMinorVectorSize,
+          /*strides=*/1);
+      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
+    }
+    rewriter.replaceOp(op, desc);
+    return matchSuccess();
+  }
+};
+
} // namespace
// TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
@@ -1188,5 +1257,9 @@ void mlir::vector::populateVectorSlicesLoweringPatterns(
+/// Adds the vector.contract lowering patterns, plus the 2-D <-> 1-D
+/// vector.shape_cast lowerings intended for the llvm.intr.matrix path.
void mlir::vector::populateVectorContractLoweringPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<ContractionOpLowering>(context);
+ patterns.insert<ContractionOpLowering,
+ // 2-D <-> 1-D shape_cast up/down-casts are meant to be emitted
+ // by contraction lowering (downcast + vector.matrix_multiply +
+ // upcast), so they are lowered together with it.
+ ShapeCastOp2DDownCastRewritePattern,
+ ShapeCastOp2DUpCastRewritePattern>(context);
}
diff --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
index 275fd0841a60..c5e40a7c18ca 100644
--- a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
@@ -250,3 +250,42 @@ func @full_contract2(%arg0: vector<2x3xf32>,
: vector<2x3xf32>, vector<3x2xf32> into f32
return %0 : f32
}
+
+// Shape up and downcasts for 2-D vectors, for supporting conversion to
+// llvm.matrix operations. A non-square shape catches row/column mixups.
+// CHECK-LABEL: func @shape_casts
+func @shape_casts(%a: vector<2x3xf32>) -> (vector<6xf32>, vector<2x3xf32>) {
+ // CHECK: %[[cst:.*]] = constant dense<0.000000e+00> : vector<6xf32>
+ // CHECK: %[[cst23:.*]] = constant dense<0.000000e+00> : vector<2x3xf32>
+ // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2x3xf32>
+ //
+ // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]]
+ // CHECK-SAME: {offsets = [0], strides = [1]} : vector<3xf32> into vector<6xf32>
+ //
+ // CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2x3xf32>
+ //
+ // CHECK: %[[in1:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]]
+ // CHECK-SAME: {offsets = [3], strides = [1]} : vector<3xf32> into vector<6xf32>
+ //
+ %0 = vector.shape_cast %a : vector<2x3xf32> to vector<6xf32>
+ // CHECK: %[[add:.*]] = addf %[[in1]], %[[in1]] : vector<6xf32>
+ %r0 = addf %0, %0: vector<6xf32>
+ //
+ // CHECK: %[[ss0:.*]] = vector.strided_slice %[[add]]
+ // CHECK-SAME: {offsets = [0], sizes = [3], strides = [1]} :
+ // CHECK-SAME: vector<6xf32> to vector<3xf32>
+ //
+ // CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[cst23]] [0] :
+ // CHECK-SAME: vector<3xf32> into vector<2x3xf32>
+ //
+ // CHECK: %[[ss1:.*]] = vector.strided_slice %[[add]]
+ // CHECK-SAME: {offsets = [3], sizes = [3], strides = [1]} :
+ // CHECK-SAME: vector<6xf32> to vector<3xf32>
+ //
+ // CHECK: %[[res1:.*]] = vector.insert %[[ss1]], %[[res0]] [1] :
+ // CHECK-SAME: vector<3xf32> into vector<2x3xf32>
+ //
+ %1 = vector.shape_cast %r0 : vector<6xf32> to vector<2x3xf32>
+ // CHECK: return %[[add]], %[[res1]] : vector<6xf32>, vector<2x3xf32>
+ return %r0, %1 : vector<6xf32>, vector<2x3xf32>
+}
More information about the Mlir-commits
mailing list