[Mlir-commits] [mlir] e4bc08f - [mlir] Allow vector.contract to have mixed types operands
Thomas Raoux
llvmlistbot at llvm.org
Fri Jun 19 17:20:03 PDT 2020
Author: Thomas Raoux
Date: 2020-06-19T17:08:57-07:00
New Revision: e4bc08f0121e4bf152592679a64af14c7cbfdba7
URL: https://github.com/llvm/llvm-project/commit/e4bc08f0121e4bf152592679a64af14c7cbfdba7
DIFF: https://github.com/llvm/llvm-project/commit/e4bc08f0121e4bf152592679a64af14c7cbfdba7.diff
LOG: [mlir] Allow vector.contract to have mixed types operands
Allow lhs and rhs to have a different element type than the accumulator/destination. Some
hardware, such as GPUs, natively supports operations like uint8 x uint8 -> uint32.
Differential Revision: https://reviews.llvm.org/D82069
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 9ae1c74df9e9..70ee272c8cef 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -40,12 +40,9 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
// with operators other than the current set: {*, +}.
def Vector_ContractionOp :
Vector_Op<"contract", [NoSideEffect,
- PredOpTrait<"first operand lhs and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 0>>,
- PredOpTrait<"second operand rhs and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 1>>,
+ PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
PredOpTrait<"third operand acc and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 1>>]>,
+ TCresVTEtIsSameAsOpBase<0, 2>>]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
Variadic<VectorOf<[I1]>>:$masks,
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
@@ -140,6 +137,11 @@ def Vector_ContractionOp :
%5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
+
+ // Vector contraction with mixed typed. lhs/rhs have different element
+ // types than accumulator/result.
+ %6 = vector.contract #contraction_trait %0, %1, %2
+ : vector<10xf16>, vector<10xf16> into f32
```
}];
let builders = [OpBuilder<
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index f23ac10770ae..a9207230278f 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -28,6 +28,7 @@
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "llvm/Support/CommandLine.h"
@@ -1731,6 +1732,11 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
// TODO(ajcbik): implement masks.
if (llvm::size(op.masks()) != 0)
return failure();
+ // TODO(thomasraoux): support mixed mode contract lowering.
+ if (op.getLhsType().getElementType() !=
+ getElementTypeOrSelf(op.getAccType()) ||
+ op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
+ return failure();
// TODO(ntv, ajcbik): implement benefits, cost models.
MLIRContext *ctx = op.getContext();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 52d0586e98f2..84d596bf512f 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -760,7 +760,7 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
func @contraction(%arg0: vector<4x3xi32>,
%arg1: vector<3x7xf32>,
%arg2: vector<4x7xf32>) -> vector<4x7xf32> {
- // expected-error@+1 {{'vector.contract' op failed to verify that first operand lhs and result have same element type}}
+ // expected-error@+1 {{'vector.contract' op failed to verify that lhs and rhs have same element type}}
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
: vector<4x3xi32>, vector<3x7xf32> into vector<4x7xf32>
}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index dbffe4206f12..02ee4dd3883b 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -175,7 +175,7 @@ func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32
// CHECK-LABEL: @contraction
func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
%arg2 : vector<8x15x5xf32>, %arg3 : vector<8x8x15x5xf32>,
- %arg4 : index) {
+ %arg4 : vector<7x8x16x15xf16>, %arg5 : vector<8x16x7x5xf16>) {
// Test contraction with batch and contracting dims.
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
%0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2
@@ -193,6 +193,10 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
%2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask,
%rhs_mask
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
+ // Test contraction with mixed type.
+ // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
+ %3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3
+ : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
return
}
More information about the Mlir-commits
mailing list