[Mlir-commits] [mlir] d3ddbe1 - [mlir][vector] Clarify vector.contract promotion behavior

Lei Zhang llvmlistbot at llvm.org
Mon Jan 30 16:08:34 PST 2023


Author: Lei Zhang
Date: 2023-01-31T00:08:26Z
New Revision: d3ddbe153e4ce1377653c8fb2936334bf9d105cf

URL: https://github.com/llvm/llvm-project/commit/d3ddbe153e4ce1377653c8fb2936334bf9d105cf
DIFF: https://github.com/llvm/llvm-project/commit/d3ddbe153e4ce1377653c8fb2936334bf9d105cf.diff

LOG: [mlir][vector] Clarify vector.contract promotion behavior

This commit updates vector.contract documentation to clarify
the promotion behavior if operands and the result have different
bitwidths. It also adds a check to disable signed/unsigned integer
types and only allow signless integers.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D142915

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index e0711a48b3004..3dc3123ccb8e3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -110,6 +110,11 @@ def Vector_ContractionOp :
     num_batch_dims (see dimension type descriptions below)). For K = 0 (no
     free or batch dimensions), the accumulator and output are a scalar.
 
+    If operands and the result have types of different bitwidths, operands are
+    promoted to have the same bitwidth as the result before performing the
+    contraction. For integer types, only signless integer types are supported,
+    and the promotion happens via sign extension.
+
     Optional vector mask arguments (produced by CreateMaskOp or ConstantMaskOp)
     specify the dynamic dimension sizes of valid data within the lhs/rhs vector
     arguments.

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 974854e5637f5..cefd629127842 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -803,10 +803,15 @@ static LogicalResult verifyOutputShape(
 }
 
 LogicalResult ContractionOp::verify() {
-  auto lhsType = getLhsType();
-  auto rhsType = getRhsType();
-  auto accType = getAccType();
-  auto resType = getResultType();
+  VectorType lhsType = getLhsType();
+  VectorType rhsType = getRhsType();
+  Type accType = getAccType();
+  Type resType = getResultType();
+
+  if (lhsType.getElementType().isa<IntegerType>()) {
+    if (!lhsType.getElementType().isSignlessInteger())
+      return emitOpError("only supports signless integer types");
+  }
 
   // Verify that an indexing map was specified for each vector operand.
   if (getIndexingMapsArray().size() != 3)

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index e0e8ed35c77f4..5132fa8368996 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1652,3 +1652,14 @@ func.func @vector_scalable_extract_unaligned(%vec: vector<[16]xf32>) {
   // expected-error@+1 {{op failed to verify that position is a multiple of the result length.}}
   %0 = vector.scalable.extract %vec[5] : vector<4xf32> from vector<[16]xf32>
 }
+
+// -----
+
+func.func @integer_vector_contract(%arg0: vector<16x32xsi8>, %arg1: vector<32x16xsi8>, %arg2: vector<16x16xsi32>) -> vector<16x16xsi32> {
+  // expected-error@+1 {{op only supports signless integer types}}
+  %0 = vector.contract {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>
+  } %arg0, %arg1, %arg2 : vector<16x32xsi8>, vector<32x16xsi8> into vector<16x16xsi32>
+  return %0: vector<16x16xsi32>
+}


        

