[Mlir-commits] [mlir] 1e45b55 - [mlir] [VectorOps] Handle 'vector.shape_cast' lowering for all cases
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 9 16:08:59 PDT 2020
Author: aartbik
Date: 2020-06-09T16:08:45-07:00
New Revision: 1e45b55dcc8bc34a45e984ce2a5533a292775484
URL: https://github.com/llvm/llvm-project/commit/1e45b55dcc8bc34a45e984ce2a5533a292775484
DIFF: https://github.com/llvm/llvm-project/commit/1e45b55dcc8bc34a45e984ce2a5533a292775484.diff
LOG: [mlir] [VectorOps] Handle 'vector.shape_cast' lowering for all cases
Summary:
Even though this operation is currently intended only for 1-D/2-D conversions,
leaving a semantic hole in the lowering prevents proper testing of the
operation. This CL adds a straightforward reference implementation for the
remaining cases.
Reviewers: nicolasvasilache, mehdi_amini, ftynse, reidtatge
Reviewed By: reidtatge
Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, msifontes
Tags: #mlir
Differential Revision: https://reviews.llvm.org/D81503
Added:
Modified:
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 99a3a951e9de..ebb0d1109bc9 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1466,6 +1466,61 @@ class ShapeCastOp2DUpCastRewritePattern
}
};
+// We typically should not lower general shape cast operations into data
+// movement instructions, since the assumption is that these casts are
+// optimized away during progressive lowering. For completeness, however,
+// we fall back to a reference implementation that moves all elements
+// into the right place if we get here.
+//
+// 2D->1D and 1D->2D casts are rejected here and left to the dedicated
+// 2D down-cast / up-cast patterns. Every other rank combination is unrolled
+// into one scalar vector.extract / vector.insert pair per element, walking
+// the source and result index spaces in row-major order in lockstep.
+class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
+public:
+  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto sourceVectorType = op.getSourceVectorType();
+    auto resultVectorType = op.getResultVectorType();
+    // 2D->1D and 1D->2D casts are lowered by dedicated patterns with better
+    // implementations; do not compete with them.
+    int64_t srcRank = sourceVectorType.getRank();
+    int64_t resRank = resultVectorType.getRank();
+    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
+      return failure();
+    // Number of scalar elements moved by the reshape. The result is expected
+    // to hold the same number of elements as the source.
+    int64_t numElts = sourceVectorType.getNumElements();
+    // Replace with data movement operations:
+    //    x[0,0,0] = y[0,0]
+    //    x[0,0,1] = y[0,1]
+    //    x[0,1,0] = y[0,2]
+    // etc., incrementing the two index vectors "row-major"
+    // within the source and result shape.
+    SmallVector<int64_t, 4> srcIdx(srcRank);
+    SmallVector<int64_t, 4> resIdx(resRank);
+    Value result = rewriter.create<ConstantOp>(
+        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+    for (int64_t i = 0; i < numElts; i++) {
+      // Both index vectors start at all-zeros; advance them before every
+      // element except the first.
+      if (i != 0) {
+        incIdx(srcIdx, sourceVectorType, srcRank - 1);
+        incIdx(resIdx, resultVectorType, resRank - 1);
+      }
+      Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
+      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  // Advances the multi-dimensional index `idx` by one position in row-major
+  // order over the shape of `tp`. Dimension `r` is incremented first (callers
+  // pass the innermost dimension); on wrap-around it is reset to 0 and the
+  // carry propagates to dimension r - 1. Callers must not step past the last
+  // element (guarded by the assert).
+  static void incIdx(MutableArrayRef<int64_t> idx, VectorType tp, int64_t r) {
+    assert(0 <= r && r < tp.getRank());
+    if (++idx[r] == tp.getDimSize(r)) {
+      idx[r] = 0;
+      incIdx(idx, tp, r - 1);
+    }
+  }
+};
+
} // namespace
namespace mlir {
@@ -1864,7 +1919,8 @@ void mlir::vector::populateVectorContractLoweringPatterns(
ConstantMaskOpLowering,
OuterProductOpLowering,
ShapeCastOp2DDownCastRewritePattern,
- ShapeCastOp2DUpCastRewritePattern>(context);
+ ShapeCastOp2DUpCastRewritePattern,
+ ShapeCastOpRewritePattern>(context);
patterns.insert<TransposeOpLowering,
ContractionOpLowering,
ContractionOpToMatmulOpLowering,
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index da784205224a..a0f5e66fea4b 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -319,7 +319,6 @@ func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
return %0 : vector<3x2xf32>
}
-
// CHECK-LABEL: func @nop_shape_cast
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
// CHECK: return %[[A]] : vector<16xf32>
@@ -378,6 +377,72 @@ func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
return %r0, %1 : vector<4xf32>, vector<2x2xf32>
}
+// CHECK-LABEL: func @shape_cast_2d2d
+// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
+// CHECK: %[[C:.*]] = constant dense<0.000000e+00> : vector<2x3xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<3x2xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<3x2xf32>
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32>
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<3x2xf32>
+// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<3x2xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32>
+// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : vector<3x2xf32>
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32>
+// CHECK: return %[[T11]] : vector<2x3xf32>
+
+// A 2-D to 2-D reshape is neither of the 2D<->1D special cases, so it is
+// lowered by the generic element-by-element fallback.
+func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
+  %0 = vector.shape_cast %arg0 : vector<3x2xf32> to vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_3d1d
+// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
+// CHECK: %[[C:.*]] = constant dense<0.000000e+00> : vector<6xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : vector<1x3x2xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : vector<1x3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1, 0] : vector<1x3x2xf32>
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32>
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : vector<1x3x2xf32>
+// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : vector<1x3x2xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32>
+// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : vector<1x3x2xf32>
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32>
+// CHECK: return %[[T11]] : vector<6xf32>
+
+// A 3-D to 1-D reshape (flattening) goes through the generic fallback.
+func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
+  %0 = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
+  return %0 : vector<6xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_1d3d
+// CHECK-SAME: %[[A:.*]]: vector<6xf32>
+// CHECK: %[[C:.*]] = constant dense<0.000000e+00> : vector<2x1x3xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<6xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : vector<6xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : vector<6xf32>
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32>
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : vector<6xf32>
+// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : vector<6xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32>
+// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : vector<6xf32>
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32>
+// CHECK: return %[[T11]] : vector<2x1x3xf32>
+
+// A 1-D to 3-D reshape (unflattening) goes through the generic fallback.
+func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
+  %0 = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
+  return %0 : vector<2x1x3xf32>
+}
+
// MATRIX-LABEL: func @matmul
// MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
// MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
More information about the Mlir-commits
mailing list