[Mlir-commits] [mlir] d3ddbe1 - [mlir][vector] Clarify vector.contract promotion behavior
Lei Zhang
llvmlistbot at llvm.org
Mon Jan 30 16:08:34 PST 2023
Author: Lei Zhang
Date: 2023-01-31T00:08:26Z
New Revision: d3ddbe153e4ce1377653c8fb2936334bf9d105cf
URL: https://github.com/llvm/llvm-project/commit/d3ddbe153e4ce1377653c8fb2936334bf9d105cf
DIFF: https://github.com/llvm/llvm-project/commit/d3ddbe153e4ce1377653c8fb2936334bf9d105cf.diff
LOG: [mlir][vector] Clarify vector.contract promotion behavior
This commit updates vector.contract documentation to clarify
the promotion behavior if operands and the result have different
bitwidths. It also adds a check to disable signed/unsigned integer
types and only allow signless integers.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D142915
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index e0711a48b3004..3dc3123ccb8e3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -110,6 +110,11 @@ def Vector_ContractionOp :
num_batch_dims (see dimension type descriptions below)). For K = 0 (no
free or batch dimensions), the accumulator and output are a scalar.
+ If operands and the result have types of different bitwidths, operands are
+ promoted to have the same bitwidth as the result before performing the
+ contraction. For integer types, only signless integer types are supported,
+ and the promotion happens via sign extension.
+
Optional vector mask arguments (produced by CreateMaskOp or ConstantMaskOp)
specify the dynamic dimension sizes of valid data within the lhs/rhs vector
arguments.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 974854e5637f5..cefd629127842 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -803,10 +803,15 @@ static LogicalResult verifyOutputShape(
}
LogicalResult ContractionOp::verify() {
- auto lhsType = getLhsType();
- auto rhsType = getRhsType();
- auto accType = getAccType();
- auto resType = getResultType();
+ VectorType lhsType = getLhsType();
+ VectorType rhsType = getRhsType();
+ Type accType = getAccType();
+ Type resType = getResultType();
+
+ if (lhsType.getElementType().isa<IntegerType>()) {
+ if (!lhsType.getElementType().isSignlessInteger())
+ return emitOpError("only supports signless integer types");
+ }
// Verify that an indexing map was specified for each vector operand.
if (getIndexingMapsArray().size() != 3)
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index e0e8ed35c77f4..5132fa8368996 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1652,3 +1652,14 @@ func.func @vector_scalable_extract_unaligned(%vec: vector<[16]xf32>) {
// expected-error at +1 {{op failed to verify that position is a multiple of the result length.}}
%0 = vector.scalable.extract %vec[5] : vector<4xf32> from vector<[16]xf32>
}
+
+// -----
+
+func.func @integer_vector_contract(%arg0: vector<16x32xsi8>, %arg1: vector<32x16xsi8>, %arg2: vector<16x16xsi32>) -> vector<16x16xsi32> {
+ // expected-error@+1 {{op only supports signless integer types}}
+ %0 = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>
+ } %arg0, %arg1, %arg2 : vector<16x32xsi8>, vector<32x16xsi8> into vector<16x16xsi32>
+ return %0: vector<16x16xsi32>
+}
More information about the Mlir-commits
mailing list