[Mlir-commits] [mlir] 397336d - [mlir][vector] Add missing support for contract of integer lowering.

Thomas Raoux llvmlistbot at llvm.org
Tue Feb 16 07:17:02 PST 2021


Author: Thomas Raoux
Date: 2021-02-16T07:13:30-08:00
New Revision: 397336dcab81dd0bb95e50e95c737c3e77ee7985

URL: https://github.com/llvm/llvm-project/commit/397336dcab81dd0bb95e50e95c737c3e77ee7985
DIFF: https://github.com/llvm/llvm-project/commit/397336dcab81dd0bb95e50e95c737c3e77ee7985.diff

LOG: [mlir][vector] Add missing support for contract of integer lowering.

Some of the vector.contract lowerings didn't support the integer case. Since
an integer reduction cannot take an accumulator, we now always break up the
reduction op into a reduction followed by a separate add; the two should be
merged by a separate canonicalization if possible.

Differential Revision: https://reviews.llvm.org/D96461

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 0a6c88d4d99b..200eb55076f7 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1716,6 +1716,24 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
 
 } // namespace
 
+/// Builds an integer add (AddIOp) when `isInt` is set, and a floating-point
+/// add (AddFOp) otherwise, combining operands `x` and `y`.
+static Value createAdd(Location loc, Value x, Value y, bool isInt,
+                       PatternRewriter &rewriter) {
+  if (!isInt)
+    return rewriter.create<AddFOp>(loc, x, y);
+  return rewriter.create<AddIOp>(loc, x, y);
+}
+
+/// Builds an integer multiply (MulIOp) when `isInt` is set, and a
+/// floating-point multiply (MulFOp) otherwise, combining operands `x` and `y`.
+static Value createMul(Location loc, Value x, Value y, bool isInt,
+                       PatternRewriter &rewriter) {
+  if (!isInt)
+    return rewriter.create<MulFOp>(loc, x, y);
+  return rewriter.create<MulIOp>(loc, x, y);
+}
+
 namespace mlir {
 
 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
@@ -2003,13 +2021,14 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
   // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
   Value res =
       rewriter.create<ConstantOp>(loc, dstType, rewriter.getZeroAttr(dstType));
+  bool isInt = dstType.getElementType().isa<IntegerType>();
   for (unsigned r = 0; r < dstRows; ++r) {
     Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
     for (unsigned c = 0; c < dstColumns; ++c) {
       Value b = rank == 1
                     ? rhs
                     : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
-      Value m = rewriter.create<MulFOp>(op.getLoc(), a, b);
+      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
       Value reduced = rewriter.create<vector::ReductionOp>(
           op.getLoc(), dstType.getElementType(), rewriter.getStringAttr("add"),
           m, ValueRange{});
@@ -2020,7 +2039,7 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
     }
   }
   if (auto acc = op.acc())
-    res = rewriter.create<AddFOp>(op.getLoc(), res, acc);
+    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
   rewriter.replaceOp(op, res);
   return success();
 }
@@ -2176,6 +2195,7 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
   VectorType rhsType = op.getRhsType();
   Type resType = op.getResultType();
   assert(!resType.isa<VectorType>());
+  bool isInt = resType.isa<IntegerType>();
   // Use iterator index 0.
   int64_t iterIndex = 0;
   SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
@@ -2190,10 +2210,13 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
   // Base case.
   if (lhsType.getRank() == 1) {
     assert(rhsType.getRank() == 1 && "corrupt contraction");
-    Value m = rewriter.create<MulFOp>(loc, op.lhs(), op.rhs());
+    Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter);
     StringAttr kind = rewriter.getStringAttr("add");
-    return rewriter.create<vector::ReductionOp>(loc, resType, kind, m,
-                                                op.acc());
+    Value res = rewriter.create<vector::ReductionOp>(loc, resType, kind, m,
+                                                     ValueRange{});
+    if (auto acc = op.acc())
+      res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
+    return res;
   }
   // Construct new iterator types and affine map array attribute.
   std::array<AffineMap, 3> lowIndexingMaps = {

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 2c3ac0fe97bb..3adb18c1a2ae 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -18,8 +18,9 @@
 // CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
 // CHECK-SAME: %[[C:.*2]]: f32
 // CHECK:      %[[F:.*]] = mulf %[[A]], %[[B]] : vector<4xf32>
-// CHECK:      %[[R:.*]] = vector.reduction "add", %[[F]], %[[C]] : vector<4xf32> into f32
-// CHECK:      return %[[R]] : f32
+// CHECK:      %[[R:.*]] = vector.reduction "add", %[[F]] : vector<4xf32> into f32
+// CHECK:      %[[ACC:.*]] = addf %[[R]], %[[C]] : f32
+// CHECK:      return %[[ACC]] : f32
 
 func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
   %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
@@ -27,6 +28,21 @@ func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32)
   return %0 : f32
 }
 
+// CHECK-LABEL: func @extract_contract1_int
+// CHECK-SAME: %[[A:.*0]]: vector<4xi32>,
+// CHECK-SAME: %[[B:.*1]]: vector<4xi32>,
+// CHECK-SAME: %[[C:.*2]]: i32
+// CHECK:      %[[F:.*]] = muli %[[A]], %[[B]] : vector<4xi32>
+// CHECK:      %[[R:.*]] = vector.reduction "add", %[[F]] : vector<4xi32> into i32
+// CHECK:      %[[ACC:.*]] = addi %[[R]], %[[C]] : i32
+// CHECK:      return %[[ACC]] : i32
+
+func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %arg2: i32) -> i32 {
+  %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
+    : vector<4xi32>, vector<4xi32> into i32
+  return %0 : i32
+}
+
 #matvec_accesses = [
   affine_map<(i, j) -> (i, j)>,
   affine_map<(i, j) -> (j)>,
@@ -61,6 +77,29 @@ func @extract_contract2(%arg0: vector<2x3xf32>,
   return %0 : vector<2xf32>
 }
 
+// CHECK-LABEL: func @extract_contract2_int
+// CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>,
+// CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
+// CHECK-SAME: %[[C:.*2]]: vector<2xi32>
+// CHECK:      %[[R:.*]] = constant dense<0> : vector<2xi32>
+// CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xi32>
+// CHECK:      %[[T2:.*]] = muli %[[T0]], %[[B]] : vector<3xi32>
+// CHECK:      %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xi32> into i32
+// CHECK:      %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : i32 into vector<2xi32>
+// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xi32>
+// CHECK:      %[[T7:.*]] = muli %[[T5]], %[[B]] : vector<3xi32>
+// CHECK:      %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xi32> into i32
+// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : i32 into vector<2xi32>
+// CHECK:      %[[T10:.*]] = addi %[[T9]], %[[C]] : vector<2xi32>
+// CHECK:      return %[[T10]] : vector<2xi32>
+func @extract_contract2_int(%arg0: vector<2x3xi32>,
+                            %arg1: vector<3xi32>,
+                            %arg2: vector<2xi32>) -> vector<2xi32> {
+  %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
+    : vector<2x3xi32>, vector<3xi32> into vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
 #vecmat_accesses = [
   affine_map<(i, j) -> (j)>,
   affine_map<(i, j) -> (i, j)>,
@@ -162,12 +201,14 @@ func @extract_contract4(%arg0: vector<2x2xf32>,
 // CHECK:      %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
 // CHECK:      %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
 // CHECK:      %[[T2:.*]] = mulf %[[T0]], %[[T1]] : vector<3xf32>
-// CHECK:      %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[C]] : vector<3xf32> into f32
-// CHECK:      %[[T4:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
-// CHECK:      %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
-// CHECK:      %[[T6:.*]] = mulf %[[T4]], %[[T5]] : vector<3xf32>
-// CHECK:      %[[T7:.*]] = vector.reduction "add", %[[T6]], %[[T3]] : vector<3xf32> into f32
-// CHECK:      return %[[T7]] : f32
+// CHECK:      %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xf32> into f32
+// CHECK:      %[[T4:.*]] = addf %[[T3]], %[[C]] : f32
+// CHECK:      %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
+// CHECK:      %[[T6:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
+// CHECK:      %[[T7:.*]] = mulf %[[T5]], %[[T6]] : vector<3xf32>
+// CHECK:      %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xf32> into f32
+// CHECK:      %[[T9:.*]] = addf %[[T8]], %[[T4]] : f32
+// CHECK:      return %[[T9]] : f32
 
 func @full_contract1(%arg0: vector<2x3xf32>,
                      %arg1: vector<2x3xf32>,
@@ -200,7 +241,8 @@ func @full_contract1(%arg0: vector<2x3xf32>,
 // CHECK:      %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32>
 // CHECK:      %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32>
 // CHECK:      %[[T10:.*]] = mulf %[[T0]], %[[T9]] : vector<3xf32>
-// CHECK:      %[[T11:.*]] = vector.reduction "add", %[[T10]], %[[C]] : vector<3xf32> into f32
+// CHECK:      %[[T11:.*]] = vector.reduction "add", %[[T10]] : vector<3xf32> into f32
+// CHECK:      %[[ACC0:.*]] = addf %[[T11]], %[[C]] : f32
 //
 // CHECK:      %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
 // CHECK:      %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf
@@ -210,8 +252,9 @@ func @full_contract1(%arg0: vector<2x3xf32>,
 // CHECK:      %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32>
 // CHECK:      %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32>
 // CHECK:      %[[T22:.*]] = mulf %[[T12]], %[[T21]] : vector<3xf32>
-// CHECK:      %[[T23:.*]] = vector.reduction "add", %[[T22]], %[[T11]] : vector<3xf32> into f32
-// CHECK:      return %[[T23]] : f32
+// CHECK:      %[[T23:.*]] = vector.reduction "add", %[[T22]] : vector<3xf32> into f32
+// CHECK:      %[[ACC1:.*]] = addf %[[T23]], %[[ACC0]] : f32
+// CHECK:      return %[[ACC1]] : f32
 
 func @full_contract2(%arg0: vector<2x3xf32>,
                      %arg1: vector<3x2xf32>,


        


More information about the Mlir-commits mailing list