[Mlir-commits] [mlir] [mlir][linalg] Vectorization support for convolution of i1 type (PR #109480)
Nirvedh Meshram
llvmlistbot at llvm.org
Fri Sep 20 14:15:34 PDT 2024
https://github.com/nirvedhmeshram created https://github.com/llvm/llvm-project/pull/109480
Normally, convolutions appear with the following linalg op region:
```
^bb0(%arg14: i4, %arg15: i4, %arg16: i4):
%17 = arith.muli %arg14, %arg15 : i4
%18 = arith.addi %arg16, %17 : i4
linalg.yield %18 : i4
```
However, for the i1 type, due to strength reduction, we get something like:
```
^bb0(%arg14: i1, %arg15: i1, %arg16: i1):
%17 = arith.andi %arg14, %arg15 : i1
%18 = arith.ori %arg16, %17 : i1
linalg.yield %18 : i1
```
This PR updates the vectorization logic to also support this strength-reduced region form for i1 types.
>From e5fc9412991ea5331c1d7891e131bc22fe0fccbe Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Wed, 18 Sep 2024 22:31:23 +0000
Subject: [PATCH] [mlir][linalg] Vectorization support for convolution of i1
type
---
.../Linalg/Transforms/Vectorization.cpp | 34 +++++++++++++------
.../Dialect/Linalg/vectorize-convolution.mlir | 29 ++++++++++++++++
2 files changed, 52 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a376afa5ddab12..1cdf937742fd2e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2947,12 +2947,14 @@ struct Conv1DGenerator
if (!setOperKind(reduceOp))
return;
- auto maybeKind = getCombinerOpKind(reduceOp);
- if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
+ maybeKind = getCombinerOpKind(reduceOp);
+ // Typically convolution will have a `Add` CombiningKind but for i1 type it
+ // can get strength reduced to `OR` which is also supported.
+ if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
+ *maybeKind != vector::CombiningKind::OR) &&
(oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
return;
}
-
auto rhsRank = rhsShapedType.getRank();
switch (oper) {
case Conv:
@@ -3156,9 +3158,9 @@ struct Conv1DGenerator
lhsVals[linearIndex(kw, w)],
rhsVals[kw], resVals[w]);
} else {
- resVals[w] = conv1dSliceAsContraction(rewriter, loc,
- lhsVals[linearIndex(kw, w)],
- rhsVals[kw], resVals[w]);
+ resVals[w] = conv1dSliceAsContraction(
+ rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw],
+ resVals[w], maybeKind);
}
break;
case Pool:
@@ -3226,18 +3228,24 @@ struct Conv1DGenerator
}
// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
- Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
- Value lhs, Value rhs, Value res) {
+ Value
+ conv1dSliceAsContraction(RewriterBase &rewriter, Location loc, Value lhs,
+ Value rhs, Value res,
+ std::optional<vector::CombiningKind> maybeKind) {
vector::IteratorType par = vector::IteratorType::parallel;
vector::IteratorType red = vector::IteratorType::reduction;
AffineExpr n, w, f, c;
bindDims(ctx, n, w, f, c);
lhs = promote(rewriter, loc, lhs, res.getType());
rhs = promote(rewriter, loc, rhs, res.getType());
- return rewriter.create<vector::ContractionOp>(
+ auto ContrationOp = rewriter.create<vector::ContractionOp>(
loc, lhs, rhs, res,
/*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
/*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
+ if (maybeKind) {
+ ContrationOp.setKind(*maybeKind);
+ }
+ return ContrationOp;
}
// Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
@@ -3627,6 +3635,7 @@ struct Conv1DGenerator
int strideW, dilationW;
Value lhsShaped, rhsShaped, resShaped;
ShapedType lhsShapedType, rhsShapedType, resShapedType;
+ std::optional<vector::CombiningKind> maybeKind;
// Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
// Returns true iff it is a valid conv/pooling op.
@@ -3642,7 +3651,8 @@ struct Conv1DGenerator
switch (numBlockArguments) {
case 1: {
// Will be convolution if feeder is a MulOp.
- // Otherwise, if it can be pooling.
+ // A strength reduced version of MulOp for i1 type is AndOp which is also
+ // supported. Otherwise, it can be pooling.
auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
llvm::IsaPred<BlockArgument>);
Operation *feedOp = (*feedValIt).getDefiningOp();
@@ -3650,7 +3660,9 @@ struct Conv1DGenerator
oper = Pool;
isPoolExt = true;
poolExtOp = feedOp->getName().getIdentifier();
- } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
+ } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
+ (isa<arith::AndIOp>(feedOp) &&
+ feedOp->getResultTypes()[0].isInteger(1))) &&
llvm::all_of(feedOp->getOperands(), [](Value v) {
if (isa<BlockArgument>(v))
return true;
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 93e36a69567bd5..84e790954b4d02 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -654,6 +654,35 @@ func.func @conv_1d_nwc_wcf_mixed_int_fp_memref(%input: memref<1x2x3xi8>, %filter
// CHECK: %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[CAST0]], %[[CAST1]], %[[READ2]]
// CHECK: vector.transfer_write %[[CONTRACT]], %arg2[%[[I0]], %[[I0]], %[[I0]]]
+// -----
+
+func.func @conv2d_i1_i1_i1(%arg0: tensor<1x8x6xi1>, %arg1: tensor<8x8x1xi1>, %arg2: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> {
+ %0 = linalg.conv_1d_ncw_fcw
+ {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
+ ins(%arg0, %arg1 : tensor<1x8x6xi1>, tensor<8x8x1xi1>)
+ outs(%arg2 : tensor<1x8x6xi1>) -> tensor<1x8x6xi1>
+ return %0 : tensor<1x8x6xi1>
+}
+
+// CHECK-LABEL: func @conv2d_i1_i1_i1
+// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: tensor<1x8x6xi1>, %[[FILTER:[0-9a-z]+]]: tensor<8x8x1xi1>, %[[OUTPUT:[0-9a-z]+]]: tensor<1x8x6xi1>) -> tensor<1x8x6xi1> {
+// CHECK-DAG: %[[I0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[FALSE:.+]] = arith.constant false
+// CHECK-DAG: %[[READ0:.+]] = vector.transfer_read %[[INPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
+// CHECK-DAG: %[[READ1:.+]] = vector.transfer_read %[[FILTER]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
+// CHECK-DAG: %[[READ2:.+]] = vector.transfer_read %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]], %[[FALSE]]
+// CHECK-DAG: %[[TREAD0:.+]] = vector.transpose %[[READ0]], [0, 2, 1] : vector<1x8x6xi1> to vector<1x6x8xi1>
+// CHECK-DAG: %[[TREAD1:.+]] = vector.transpose %[[READ1]], [2, 1, 0] : vector<8x8x1xi1> to vector<1x8x8xi1>
+// CHECK-DAG: %[[TREAD2:.+]] = vector.transpose %[[READ2]], [0, 2, 1] : vector<1x8x6xi1> to vector<1x6x8xi1>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[TREAD1]][0] : vector<8x8xi1> from vector<1x8x8xi1>
+// CHECK: %[[CONTRACT:.+]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<or>}
+// CHECK-SAME: %[[TREAD0]], %[[EXTRACT]], %[[TREAD2]] : vector<1x6x8xi1>, vector<8x8xi1> into vector<1x6x8xi1>
+// CHECK: %[[TCONTRACT:.+]] = vector.transpose %[[CONTRACT]], [0, 2, 1] : vector<1x6x8xi1> to vector<1x8x6xi1>
+// CHECK: %[[RESULT:.+]] = vector.transfer_write %[[TCONTRACT]], %[[OUTPUT]][%[[I0]], %[[I0]], %[[I0]]]
+// CHECK: return %[[RESULT]] : tensor<1x8x6xi1>
+
+
+
// -----
func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: memref<1xf32>, %output: memref<4x2x3xf32>) {
More information about the Mlir-commits
mailing list