[Mlir-commits] [mlir] 397336d - [mlir][vector] Add missing support for contract of integer lowering.
Thomas Raoux
llvmlistbot at llvm.org
Tue Feb 16 07:17:02 PST 2021
Author: Thomas Raoux
Date: 2021-02-16T07:13:30-08:00
New Revision: 397336dcab81dd0bb95e50e95c737c3e77ee7985
URL: https://github.com/llvm/llvm-project/commit/397336dcab81dd0bb95e50e95c737c3e77ee7985
DIFF: https://github.com/llvm/llvm-project/commit/397336dcab81dd0bb95e50e95c737c3e77ee7985.diff
LOG: [mlir][vector] Add missing support for contract of integer lowering.
Some of the vector.contract lowerings didn't support the integer case. Since
integer reductions cannot accumulate, the reduction op is now always split into a
reduction followed by a separate add; a separate canonicalization should merge them where possible.
Differential Revision: https://reviews.llvm.org/D96461
Added:
Modified:
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 0a6c88d4d99b..200eb55076f7 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1716,6 +1716,24 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
} // namespace
+/// Creates an AddIOp if `isInt` is true, otherwise creates an AddFOp, using
+/// operands `x` and `y` at `loc`; returns the result of the created op.
+static Value createAdd(Location loc, Value x, Value y, bool isInt,
+ PatternRewriter &rewriter) {
+ if (isInt)
+ return rewriter.create<AddIOp>(loc, x, y);
+ return rewriter.create<AddFOp>(loc, x, y);
+}
+
+/// Creates a MulIOp if `isInt` is true, otherwise creates a MulFOp, using
+/// operands `x` and `y` at `loc`; returns the result of the created op.
+static Value createMul(Location loc, Value x, Value y, bool isInt,
+ PatternRewriter &rewriter) {
+ if (isInt)
+ return rewriter.create<MulIOp>(loc, x, y);
+ return rewriter.create<MulFOp>(loc, x, y);
+}
+
namespace mlir {
/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
@@ -2003,13 +2021,14 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
// ExtractOp does not allow dynamic indexing, we must unroll explicitly.
Value res =
rewriter.create<ConstantOp>(loc, dstType, rewriter.getZeroAttr(dstType));
+ bool isInt = dstType.getElementType().isa<IntegerType>();
for (unsigned r = 0; r < dstRows; ++r) {
Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
for (unsigned c = 0; c < dstColumns; ++c) {
Value b = rank == 1
? rhs
: rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
- Value m = rewriter.create<MulFOp>(op.getLoc(), a, b);
+ Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
Value reduced = rewriter.create<vector::ReductionOp>(
op.getLoc(), dstType.getElementType(), rewriter.getStringAttr("add"),
m, ValueRange{});
@@ -2020,7 +2039,7 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
}
}
if (auto acc = op.acc())
- res = rewriter.create<AddFOp>(op.getLoc(), res, acc);
+ res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
rewriter.replaceOp(op, res);
return success();
}
@@ -2176,6 +2195,7 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
VectorType rhsType = op.getRhsType();
Type resType = op.getResultType();
assert(!resType.isa<VectorType>());
+ bool isInt = resType.isa<IntegerType>();
// Use iterator index 0.
int64_t iterIndex = 0;
SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
@@ -2190,10 +2210,13 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
// Base case.
if (lhsType.getRank() == 1) {
assert(rhsType.getRank() == 1 && "corrupt contraction");
- Value m = rewriter.create<MulFOp>(loc, op.lhs(), op.rhs());
+ Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter);
StringAttr kind = rewriter.getStringAttr("add");
- return rewriter.create<vector::ReductionOp>(loc, resType, kind, m,
- op.acc());
+ Value res = rewriter.create<vector::ReductionOp>(loc, resType, kind, m,
+ ValueRange{});
+ if (auto acc = op.acc())
+ res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
+ return res;
}
// Construct new iterator types and affine map array attribute.
std::array<AffineMap, 3> lowIndexingMaps = {
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 2c3ac0fe97bb..3adb18c1a2ae 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -18,8 +18,9 @@
// CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
// CHECK-SAME: %[[C:.*2]]: f32
// CHECK: %[[F:.*]] = mulf %[[A]], %[[B]] : vector<4xf32>
-// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]], %[[C]] : vector<4xf32> into f32
-// CHECK: return %[[R]] : f32
+// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]] : vector<4xf32> into f32
+// CHECK: %[[ACC:.*]] = addf %[[R]], %[[C]] : f32
+// CHECK: return %[[ACC]] : f32
func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
%0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
@@ -27,6 +28,21 @@ func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32)
return %0 : f32
}
+// CHECK-LABEL: func @extract_contract1_int
+// CHECK-SAME: %[[A:.*0]]: vector<4xi32>,
+// CHECK-SAME: %[[B:.*1]]: vector<4xi32>,
+// CHECK-SAME: %[[C:.*2]]: i32
+// CHECK: %[[F:.*]] = muli %[[A]], %[[B]] : vector<4xi32>
+// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]] : vector<4xi32> into i32
+// CHECK: %[[ACC:.*]] = addi %[[R]], %[[C]] : i32
+// CHECK: return %[[ACC]] : i32
+// Integer (i32) counterpart of @extract_contract1, using the same #dotp_trait.
+func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %arg2: i32) -> i32 {
+ %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
+ : vector<4xi32>, vector<4xi32> into i32
+ return %0 : i32
+}
+
#matvec_accesses = [
affine_map<(i, j) -> (i, j)>,
affine_map<(i, j) -> (j)>,
@@ -61,6 +77,29 @@ func @extract_contract2(%arg0: vector<2x3xf32>,
return %0 : vector<2xf32>
}
+// CHECK-LABEL: func @extract_contract2_int
+// CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>,
+// CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
+// CHECK-SAME: %[[C:.*2]]: vector<2xi32>
+// CHECK: %[[R:.*]] = constant dense<0> : vector<2xi32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xi32>
+// CHECK: %[[T2:.*]] = muli %[[T0]], %[[B]] : vector<3xi32>
+// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xi32> into i32
+// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : i32 into vector<2xi32>
+// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xi32>
+// CHECK: %[[T7:.*]] = muli %[[T5]], %[[B]] : vector<3xi32>
+// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xi32> into i32
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : i32 into vector<2xi32>
+// CHECK: %[[T10:.*]] = addi %[[T9]], %[[C]] : vector<2xi32>
+// CHECK: return %[[T10]] : vector<2xi32>
+func @extract_contract2_int(%arg0: vector<2x3xi32>,
+ %arg1: vector<3xi32>,
+ %arg2: vector<2xi32>) -> vector<2xi32> {
+ %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2  // i32 counterpart of @extract_contract2
+ : vector<2x3xi32>, vector<3xi32> into vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
#vecmat_accesses = [
affine_map<(i, j) -> (j)>,
affine_map<(i, j) -> (i, j)>,
@@ -162,12 +201,14 @@ func @extract_contract4(%arg0: vector<2x2xf32>,
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
// CHECK: %[[T2:.*]] = mulf %[[T0]], %[[T1]] : vector<3xf32>
-// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[C]] : vector<3xf32> into f32
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
-// CHECK: %[[T6:.*]] = mulf %[[T4]], %[[T5]] : vector<3xf32>
-// CHECK: %[[T7:.*]] = vector.reduction "add", %[[T6]], %[[T3]] : vector<3xf32> into f32
-// CHECK: return %[[T7]] : f32
+// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xf32> into f32
+// CHECK: %[[T4:.*]] = addf %[[T3]], %[[C]] : f32
+// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
+// CHECK: %[[T6:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
+// CHECK: %[[T7:.*]] = mulf %[[T5]], %[[T6]] : vector<3xf32>
+// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xf32> into f32
+// CHECK: %[[T9:.*]] = addf %[[T8]], %[[T4]] : f32
+// CHECK: return %[[T9]] : f32
func @full_contract1(%arg0: vector<2x3xf32>,
%arg1: vector<2x3xf32>,
@@ -200,7 +241,8 @@ func @full_contract1(%arg0: vector<2x3xf32>,
// CHECK: %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32>
// CHECK: %[[T10:.*]] = mulf %[[T0]], %[[T9]] : vector<3xf32>
-// CHECK: %[[T11:.*]] = vector.reduction "add", %[[T10]], %[[C]] : vector<3xf32> into f32
+// CHECK: %[[T11:.*]] = vector.reduction "add", %[[T10]] : vector<3xf32> into f32
+// CHECK: %[[ACC0:.*]] = addf %[[T11]], %[[C]] : f32
//
// CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
// CHECK: %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf
@@ -210,8 +252,9 @@ func @full_contract1(%arg0: vector<2x3xf32>,
// CHECK: %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32>
// CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32>
// CHECK: %[[T22:.*]] = mulf %[[T12]], %[[T21]] : vector<3xf32>
-// CHECK: %[[T23:.*]] = vector.reduction "add", %[[T22]], %[[T11]] : vector<3xf32> into f32
-// CHECK: return %[[T23]] : f32
+// CHECK: %[[T23:.*]] = vector.reduction "add", %[[T22]] : vector<3xf32> into f32
+// CHECK: %[[ACC1:.*]] = addf %[[T23]], %[[ACC0]] : f32
+// CHECK: return %[[ACC1]] : f32
func @full_contract2(%arg0: vector<2x3xf32>,
%arg1: vector<3x2xf32>,
More information about the Mlir-commits
mailing list