[Mlir-commits] [mlir] e4bc08f - [mlir] Allow vector.contract to have mixed types operands

Thomas Raoux llvmlistbot at llvm.org
Fri Jun 19 17:20:03 PDT 2020


Author: Thomas Raoux
Date: 2020-06-19T17:08:57-07:00
New Revision: e4bc08f0121e4bf152592679a64af14c7cbfdba7

URL: https://github.com/llvm/llvm-project/commit/e4bc08f0121e4bf152592679a64af14c7cbfdba7
DIFF: https://github.com/llvm/llvm-project/commit/e4bc08f0121e4bf152592679a64af14c7cbfdba7.diff

LOG: [mlir] Allow vector.contract to have mixed types operands

Allow lhs and rhs to have a different element type from the accumulator/destination. Some
hardware, such as GPUs, natively supports mixed-type operations like uint8 x uint8 -> uint32.

Differential Revision: https://reviews.llvm.org/D82069

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 9ae1c74df9e9..70ee272c8cef 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -40,12 +40,9 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
 // with operators other than the current set: {*, +}.
 def Vector_ContractionOp :
   Vector_Op<"contract", [NoSideEffect,
-     PredOpTrait<"first operand lhs and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 0>>,
-     PredOpTrait<"second operand rhs and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 1>>,
+     PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
      PredOpTrait<"third operand acc and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 1>>]>,
+                 TCresVTEtIsSameAsOpBase<0, 2>>]>,
     Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
                Variadic<VectorOf<[I1]>>:$masks,
                AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
@@ -140,6 +137,11 @@ def Vector_ContractionOp :
 
     %5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
        : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
+
+    // Vector contraction with mixed types. lhs/rhs have different element
+    // types than accumulator/result.
+    %6 = vector.contract #contraction_trait %0, %1, %2
+      : vector<10xf16>, vector<10xf16> into f32
     ```
   }];
   let builders = [OpBuilder<

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index f23ac10770ae..a9207230278f 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -28,6 +28,7 @@
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Types.h"
 
 #include "llvm/Support/CommandLine.h"
@@ -1731,6 +1732,11 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
   // TODO(ajcbik): implement masks.
   if (llvm::size(op.masks()) != 0)
     return failure();
+  // TODO(thomasraoux): support mixed mode contract lowering.
+  if (op.getLhsType().getElementType() !=
+          getElementTypeOrSelf(op.getAccType()) ||
+      op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
+    return failure();
 
   // TODO(ntv, ajcbik): implement benefits, cost models.
   MLIRContext *ctx = op.getContext();

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 52d0586e98f2..84d596bf512f 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -760,7 +760,7 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
 func @contraction(%arg0: vector<4x3xi32>,
                   %arg1: vector<3x7xf32>,
                   %arg2: vector<4x7xf32>) -> vector<4x7xf32> {
-  // expected-error@+1 {{'vector.contract' op failed to verify that first operand lhs and result have same element type}}
+  // expected-error@+1 {{'vector.contract' op failed to verify that lhs and rhs have same element type}}
   %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
     : vector<4x3xi32>, vector<3x7xf32> into vector<4x7xf32>
 }

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index dbffe4206f12..02ee4dd3883b 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -175,7 +175,7 @@ func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32
 // CHECK-LABEL: @contraction
 func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
                   %arg2 : vector<8x15x5xf32>, %arg3 : vector<8x8x15x5xf32>,
-                  %arg4 : index) {
+                  %arg4 : vector<7x8x16x15xf16>, %arg5 : vector<8x16x7x5xf16>) {
   // Test contraction with batch and contracting dims.
   // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
   %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2
@@ -193,6 +193,10 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
   %2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask,
                                            %rhs_mask
       : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
+  // Test contraction with mixed type.
+  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
+  %3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3
+      : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
   return
 }
 


        


More information about the Mlir-commits mailing list