[Mlir-commits] [mlir] f5150ee - [mlir][linalg] Use affine apply in im2col gather index calculations
Quinn Dawkins
llvmlistbot at llvm.org
Fri Mar 24 09:27:48 PDT 2023
Author: Quinn Dawkins
Date: 2023-03-24T11:49:15-04:00
New Revision: f5150ee38eaeb9944aa837ac4096efb08f94d38b
URL: https://github.com/llvm/llvm-project/commit/f5150ee38eaeb9944aa837ac4096efb08f94d38b
DIFF: https://github.com/llvm/llvm-project/commit/f5150ee38eaeb9944aa837ac4096efb08f94d38b.diff
LOG: [mlir][linalg] Use affine apply in im2col gather index calculations
Differential Revision: https://reviews.llvm.org/D146816
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 58a23e2be54d1..491c533fc85e2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -41,37 +42,17 @@ static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) {
return builder.create<arith::MulFOp>(loc, x, y);
}
-// Unrolls the given composite `index` into a set of subindices with maximum
-// iteration ranges specified by `factors` according to the following
-// assumptions:
-// 1. The iteration range for `index` is [0, f1 * f2 * ... * fn] i.e. the
-// product of the given list of factors
-// 2. The iterators corresponding to the entries in `factors` are ordered from
-// slowest to fastest varying
-// Each subindex is then computed as:
-// subindex[i] = floor( (index % (fi * ... * fn)) / (fi-1 * ... * fn) )
-static SmallVector<Value, 3> unrollIndex(OpBuilder &b, Location loc,
- Value index,
- ArrayRef<int64_t> factors) {
+// Delinearizes `index` by `factors` (listed from slowest to fastest varying).
+static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
+ ArrayRef<int64_t> factors) {
assert(factors.size() >= 1 && "empty factor list");
- SmallVector<Value, 3> indices(factors.size());
- int64_t runningProd = 1;
- for (int i = factors.size() - 1, end = 0; i >= end; i--) {
- Value unrolledIndex = index;
- if (i > 0) {
- Value modBase = b.create<arith::ConstantOp>(
- loc, b.getIndexAttr(runningProd * factors[i]));
- unrolledIndex = b.create<arith::RemUIOp>(loc, unrolledIndex, modBase);
- }
- if (runningProd > 1) {
- Value divDenom =
- b.create<arith::ConstantOp>(loc, b.getIndexAttr(runningProd));
- unrolledIndex = b.create<arith::DivUIOp>(loc, unrolledIndex, divDenom);
- }
- runningProd *= factors[i];
- indices[i] = unrolledIndex;
- }
- return indices;
+ SmallVector<Value> basis;
+ for (int64_t f : factors)
+ basis.push_back(b.create<arith::ConstantOp>(loc, b.getIndexAttr(f)));
+ FailureOr<SmallVector<Value>> multiIndex =
+ delinearizeIndex(b, loc, index, basis);
+ assert(!failed(multiIndex) && "Failed to delinearize img2col index");
+ return *multiIndex;
}
// Given indices corresponding to iterators in the output (oIndex) and filter
@@ -79,9 +60,10 @@ static SmallVector<Value, 3> unrollIndex(OpBuilder &b, Location loc,
// input as `oIndex * stride + fIndex`.
static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
Value fIndex, int64_t stride) {
- Value strideVal = b.create<arith::ConstantOp>(loc, b.getIndexAttr(stride));
- Value convIndex = b.create<arith::MulIOp>(loc, oIndex, strideVal);
- return b.create<arith::AddIOp>(loc, convIndex, fIndex);
+ AffineExpr oExpr, fExpr; // Symbols standing for oIndex and fIndex.
+ bindSymbols(b.getContext(), oExpr, fExpr); // oExpr = s0, fExpr = s1.
+ AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
+ return makeComposedAffineApply(b, loc, convMap, ValueRange{oIndex, fIndex});
}
FailureOr<std::pair<Operation *, Operation *>>
@@ -159,12 +141,12 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
// Recover the original iteration indices from the problem/input sizes.
- SmallVector<Value, 3> mIndices = unrollIndex(
+ SmallVector<Value> mIndices = unrollIndex(
nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
auto ohIndex = mIndices[0];
auto owIndex = mIndices[1];
- SmallVector<Value, 3> kIndices = unrollIndex(
+ SmallVector<Value> kIndices = unrollIndex(
nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
auto fhIndex = kIndices[0];
auto fwIndex = kIndices[1];
@@ -443,13 +425,13 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);
// Recover the original iteration indices from the problem/input sizes.
- SmallVector<Value, 3> kIndices = unrollIndex(
+ SmallVector<Value> kIndices = unrollIndex(
nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
auto icIndex = kIndices[0];
auto fhIndex = kIndices[1];
auto fwIndex = kIndices[2];
- SmallVector<Value, 3> nIndices = unrollIndex(
+ SmallVector<Value> nIndices = unrollIndex(
nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
auto ohIndex = nIndices[0];
auto owIndex = nIndices[1];
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index ffcba1086f3f6..38c63490cf445 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -37,29 +37,12 @@ transform.sequence failures(propagate) {
// CHECK: %[[MINDEX:.+]] = linalg.index 1 : index
// CHECK: %[[KINDEX:.+]] = linalg.index 2 : index
-// Unrolled output shape indices.
-// CHECK: %[[C14:.+]] = arith.constant 14 : index
-// CHECK: %[[OWINDEX:.+]] = arith.remui %[[MINDEX]], %[[C14]] : index
-// CHECK: %[[C14_1:.+]] = arith.constant 14 : index
-// CHECK: %[[OHINDEX:.+]] = arith.divui %[[MINDEX]], %[[C14_1]] : index
+// Compute input channel/convolved indices.
+// CHECK: %[[ICINDEX:.+]] = affine.apply affine_map<(d0) -> (d0 mod 4)>(%[[KINDEX]])
+// CHECK: %[[CONVH:.+]] = affine.apply affine_map<(d0, d1) -> (d0 floordiv 14 + d1 floordiv 12)>(%[[MINDEX]], %[[KINDEX]])
+// CHECK: %[[CONVW:.+]] = affine.apply affine_map<(d0, d1) -> (d0 mod 14 + (d1 mod 12) floordiv 4)>(%[[MINDEX]], %[[KINDEX]])
-// Unrolled filter shape indices.
-// CHECK: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[ICINDEX:.+]] = arith.remui %[[KINDEX]], %[[C4]] : index
-// CHECK: %[[C12:.+]] = arith.constant 12 : index
-// CHECK: %[[FWREM:.+]] = arith.remui %[[KINDEX]], %[[C12]] : index
-// CHECK: %[[C4_2:.+]] = arith.constant 4 : index
-// CHECK: %[[FWINDEX:.+]] = arith.divui %[[FWREM]], %[[C4_2]] : index
-// CHECK: %[[C12_3:.+]] = arith.constant 12 : index
-// CHECK: %[[FHINDEX:.+]] = arith.divui %[[KINDEX]], %[[C12_3]] : index
-
-// Compute input indices.
-// CHECK: %[[SH:.+]] = arith.constant 1 : index
-// CHECK: %[[STRIDEDOH:.+]] = arith.muli %[[OHINDEX]], %[[SH]] : index
-// CHECK: %[[CONVH:.+]] = arith.addi %[[STRIDEDOH]], %[[FHINDEX]] : index
-// CHECK: %[[SW:.+]] = arith.constant 1 : index
-// CHECK: %[[STRIDEDOW:.+]] = arith.muli %[[OWINDEX]], %[[SW]] : index
-// CHECK: %[[CONVW:.+]] = arith.addi %[[STRIDEDOW]], %[[FWINDEX]] : index
+// Extract from the input tensor.
// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
// CHECK-SAME: %{{.+}}{{\[}}%[[BINDEX]], %[[CONVH]], %[[CONVW]], %[[ICINDEX]]] : tensor<1x16x16x4xf32>
// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
@@ -234,6 +217,13 @@ transform.sequence failures(propagate) {
// -----
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// Im2col maps
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 floordiv 9)>
+// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1) -> (d0 floordiv 14 + (d1 mod 9) floordiv 3)>
+// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1) -> (d0 + d1 - (d0 floordiv 14) * 14 - (d1 floordiv 3) * 3)>
+
+
// CHECK-DAG: #[[LHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
// CHECK-DAG: #[[RHSMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
// CHECK-DAG: #[[RESMAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -252,29 +242,12 @@ transform.sequence failures(propagate) {
// CHECK: %[[KINDEX:.+]] = linalg.index 1 : index
// CHECK: %[[NINDEX:.+]] = linalg.index 2 : index
-// Unrolled filter shape indices.
-// CHECK: %[[C3:.+]] = arith.constant 3 : index
-// CHECK: %[[FWINDEX:.+]] = arith.remui %[[KINDEX]], %[[C3]] : index
-// CHECK: %[[C9:.+]] = arith.constant 9 : index
-// CHECK: %[[FHREM:.+]] = arith.remui %[[KINDEX]], %[[C9]] : index
-// CHECK: %[[C3_1:.+]] = arith.constant 3 : index
-// CHECK: %[[FHINDEX:.+]] = arith.divui %[[FHREM]], %[[C3_1]] : index
-// CHECK: %[[C9_2:.+]] = arith.constant 9 : index
-// CHECK: %[[ICINDEX:.+]] = arith.divui %[[KINDEX]], %[[C9_2]] : index
-
-// Unrolled output shape indices.
-// CHECK: %[[C14:.+]] = arith.constant 14 : index
-// CHECK: %[[OWINDEX:.+]] = arith.remui %[[NINDEX]], %[[C14]] : index
-// CHECK: %[[C14_3:.+]] = arith.constant 14 : index
-// CHECK: %[[OHINDEX:.+]] = arith.divui %[[NINDEX]], %[[C14_3]] : index
+// Compute input channel/convolved indices.
+// CHECK: %[[ICINDEX:.+]] = affine.apply #[[MAP1]](%[[KINDEX]])
+// CHECK: %[[CONVH:.+]] = affine.apply #[[MAP7]](%[[NINDEX]], %[[KINDEX]])
+// CHECK: %[[CONVW:.+]] = affine.apply #[[MAP8]](%[[NINDEX]], %[[KINDEX]])
-// Compute input indices.
-// CHECK: %[[SH:.+]] = arith.constant 1 : index
-// CHECK: %[[STRIDEDOH:.+]] = arith.muli %[[OHINDEX]], %[[SH]] : index
-// CHECK: %[[CONVH:.+]] = arith.addi %[[STRIDEDOH]], %[[FHINDEX]] : index
-// CHECK: %[[SW:.+]] = arith.constant 1 : index
-// CHECK: %[[STRIDEDOW:.+]] = arith.muli %[[OWINDEX]], %[[SW]] : index
-// CHECK: %[[CONVW:.+]] = arith.addi %[[STRIDEDOW]], %[[FWINDEX]] : index
+// Extract from the input tensor.
// CHECK: %[[EXTRACTED_INPUT:.+]] = tensor.extract
// CHECK-SAME: %[[INPUT]]{{\[}}%[[BINDEX]], %[[ICINDEX]], %[[CONVH]], %[[CONVW]]] : tensor<8x4x16x16xf32>
// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
More information about the Mlir-commits mailing list