[Mlir-commits] [mlir] cca4ac5 - [mlir][VectorOps] Lower vector.outerproduct of int vectors

Benjamin Kramer llvmlistbot at llvm.org
Tue Jul 7 05:42:43 PDT 2020


Author: Benjamin Kramer
Date: 2020-07-07T14:40:07+02:00
New Revision: cca4ac523e183b33be3d4c7da68d45b697a081bd

URL: https://github.com/llvm/llvm-project/commit/cca4ac523e183b33be3d4c7da68d45b697a081bd
DIFF: https://github.com/llvm/llvm-project/commit/cca4ac523e183b33be3d4c7da68d45b697a081bd.diff

LOG: [mlir][VectorOps] Lower vector.outerproduct of int vectors

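vector.fma and mulf only operate on floating-point values, so they cannot be
used for integer vectors. For integer element types, lower to a muli/addi
pair when an accumulator is present, or to a plain muli otherwise.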

Differential Revision: https://reviews.llvm.org/D83292

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 19c5bdcf97f2..ebad34fcd593 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1289,9 +1289,16 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
       Value m;
       if (acc) {
         Value e = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
-        m = rewriter.create<vector::FMAOp>(loc, b, op.rhs(), e);
+        if (eltType.isa<IntegerType>())
+          m = rewriter.create<AddIOp>(
+              loc, rewriter.create<MulIOp>(loc, b, op.rhs()), e);
+        else
+          m = rewriter.create<vector::FMAOp>(loc, b, op.rhs(), e);
       } else {
-        m = rewriter.create<MulFOp>(loc, b, op.rhs());
+        if (eltType.isa<IntegerType>())
+          m = rewriter.create<MulIOp>(loc, b, op.rhs());
+        else
+          m = rewriter.create<MulFOp>(loc, b, op.rhs());
       }
       result = rewriter.create<vector::InsertOp>(loc, resType, m, result, pos);
     }

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 0ebe049f5351..6a933d5e24b5 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -293,6 +293,50 @@ func @outerproduct_acc(%arg0: vector<2xf32>,
   return %0: vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @outerproduct_noacc_int
+// CHECK-SAME: %[[A:.*0]]: vector<2xi32>,
+// CHECK-SAME: %[[B:.*1]]: vector<3xi32>
+// CHECK:      %[[C0:.*]] = constant dense<0> : vector<2x3xi32>
+// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32>
+// CHECK:      %[[T1:.*]] = splat %[[T0]] : vector<3xi32>
+// CHECK:      %[[T2:.*]] = muli %[[T1]], %[[B]] : vector<3xi32>
+// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
+// CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xi32>
+// CHECK:      %[[T5:.*]] = splat %[[T4]] : vector<3xi32>
+// CHECK:      %[[T6:.*]] = muli %[[T5]], %[[B]] : vector<3xi32>
+// CHECK:      %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
+// CHECK:      return %[[T7]] : vector<2x3xi32>
+func @outerproduct_noacc_int(%arg0: vector<2xi32>,
+                             %arg1: vector<3xi32>) -> vector<2x3xi32> {
+  %0 = vector.outerproduct %arg0, %arg1 : vector<2xi32>, vector<3xi32>
+  return %0: vector<2x3xi32>
+}
+
+// CHECK-LABEL: func @outerproduct_acc_int
+// CHECK-SAME: %[[A:.*0]]: vector<2xi32>,
+// CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
+// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
+// CHECK:      %[[C0:.*]] = constant dense<0> : vector<2x3xi32>
+// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32>
+// CHECK:      %[[T1:.*]] = splat %[[T0]] : vector<3xi32>
+// CHECK:      %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xi32>
+// CHECK:      %[[T3:.*]] = muli %[[T1]], %[[B]] : vector<3xi32>
+// CHECK:      %[[T4:.*]] = addi %[[T3]], %[[T2]] : vector<3xi32>
+// CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
+// CHECK:      %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xi32>
+// CHECK:      %[[T7:.*]] = splat %[[T6]] : vector<3xi32>
+// CHECK:      %[[T8:.*]] = vector.extract %[[C]][1] : vector<2x3xi32>
+// CHECK:      %[[T9:.*]] = muli %[[T7]], %[[B]] : vector<3xi32>
+// CHECK:      %[[T10:.*]] = addi %[[T9]], %[[T8]] : vector<3xi32>
+// CHECK:      %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3xi32> into vector<2x3xi32>
+// CHECK:      return %[[T11]] : vector<2x3xi32>
+func @outerproduct_acc_int(%arg0: vector<2xi32>,
+                           %arg1: vector<3xi32>,
+                           %arg2: vector<2x3xi32>) -> vector<2x3xi32> {
+  %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xi32>, vector<3xi32>
+  return %0: vector<2x3xi32>
+}
+
 // CHECK-LABEL: func @transpose23
 // CHECK-SAME: %[[A:.*]]: vector<2x3xf32>
 // CHECK:      %[[Z:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>


        


More information about the Mlir-commits mailing list