[Mlir-commits] [mlir] c7b1176 - [mlir][linalg] Make Linalg vectorizer lower affine.apply
Andrzej Warzynski
llvmlistbot at llvm.org
Fri Jan 27 00:31:05 PST 2023
Author: Andrzej Warzynski
Date: 2023-01-27T08:30:50Z
New Revision: c7b1176e9afbfcc3da9482abbf7c1eb8793ff254
URL: https://github.com/llvm/llvm-project/commit/c7b1176e9afbfcc3da9482abbf7c1eb8793ff254
DIFF: https://github.com/llvm/llvm-project/commit/c7b1176e9afbfcc3da9482abbf7c1eb8793ff254.diff
LOG: [mlir][linalg] Make Linalg vectorizer lower affine.apply
It is possible that the input to the Linalg vectorizer contains
`affine.apply` ops (see the example in [1]). Such operations are not
vectorizable at the moment, but this can be fixed by simply converting
them to arithmetic operations. This is basically what this patch
introduces.
The IR change enabled in this patch could be part of a larger set of
"linalgOp pre-processing" transformations that happen right before
vectorization starts but after we know we can vectorize the op. I am
leaving this as a TODO.
[1] https://github.com/iree-org/iree/issues/10876.
Differential Revision: https://reviews.llvm.org/D142371
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index d81496ed0f911..b356de9bbdf86 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -9,6 +9,7 @@
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -1048,6 +1049,21 @@ mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
return success();
}
+/// Converts every `affine.apply` op in the body of `linalgOp` into equivalent
+/// arithmetic operations (via `expandAffineExpr`) so that the vectorizer only
+/// sees ops it knows how to handle.
+///
+/// The expansion of each op is inserted immediately before that op. This
+/// guarantees that all of its operands dominate the newly created arithmetic
+/// ops. Inserting at a fixed point, such as right after the first op of the
+/// block, breaks as soon as an operand is defined later in the block.
+///
+/// `affine.apply` operands are the map's dimension values followed by its
+/// symbol values. They are split accordingly, so maps with symbols (e.g.
+/// `(d0)[s0] -> (d0 + s0)`) are expanded correctly.
+static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
+  OpBuilder::InsertionGuard g(rewriter);
+  auto toReplace = linalgOp.getBlock()->getOps<AffineApplyOp>();
+
+  // `make_early_inc_range` because `replaceOp` erases the current op.
+  for (auto op : make_early_inc_range(toReplace)) {
+    rewriter.setInsertionPoint(op);
+    AffineMap map = op.getAffineMap();
+    ValueRange operands = op.getOperands();
+    Value expanded = expandAffineExpr(
+        rewriter, op->getLoc(), map.getResult(0),
+        /*dimValues=*/operands.take_front(map.getNumDims()),
+        /*symbolValues=*/operands.take_back(map.getNumSymbols()));
+    // `expandAffineExpr` can return a null Value when it cannot lower the
+    // expression; never replace uses with a null value. Leave the op as is.
+    if (!expanded)
+      continue;
+    rewriter.replaceOp(op, expanded);
+  }
+}
+
/// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes`
/// are used to vectorize this operation. `inputVectorSizes` must match the rank
/// of the iteration space of the operation and the sizes must be smaller or
@@ -1084,6 +1100,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
vectorizeNDExtract)))
return failure();
LDBG("Vectorize generic by broadcasting to the canonical vector shape\n");
+
+ // Pre-process before proceeding.
+ convertAffineApply(rewriter, linalgOp);
+
// TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted to
// 'OpBuilder' when it is passed over to some methods like
// 'vectorizeAsLinalgGeneric'. This is highly problematic: if we erase an op
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index e02702a0ffa7b..c45c34c65d09e 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -164,7 +164,7 @@ bool hasOnlyScalarElementwiseOp(Region &r) {
return false;
for (Operation &op : r.front()) {
if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
- linalg::YieldOp, linalg::IndexOp>(op) ||
+ linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
OpTrait::hasElementwiseMappableTraits(&op)) ||
llvm::any_of(op.getResultTypes(),
[](Type type) { return !type.isIntOrIndexOrFloat(); }))
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 171a518447697..f966b0e241159 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -290,6 +290,43 @@ transform.sequence failures(propagate) {
// -----
+#map0 = affine_map<(d0) -> (d0)>
+
+// Checks that an `affine.apply` in the body of a linalg.generic is lowered to
+// arithmetic ops and then vectorized.
+func.func @vectorize_affine_apply(%arg0: tensor<32xf32>, %arg3: index) -> tensor<32xi32> {
+  %0 = tensor.empty() : tensor<32xi32>
+  %1 = linalg.generic {indexing_maps = [#map0, #map0],
+                       iterator_types = ["parallel"]}
+    ins(%arg0 : tensor<32xf32>)
+    outs(%0 : tensor<32xi32>) {
+  ^bb0(%arg1: f32, %arg2: i32):
+    %2 = linalg.index 0 : index
+    %12 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg3)
+    %3 = arith.index_cast %12 : index to i32
+    linalg.yield %3 : i32
+  } -> tensor<32xi32>
+  return %1 : tensor<32xi32>
+}
+
+// CHECK-LABEL: func.func @vectorize_affine_apply
+// CHECK-SAME: %{{.*}}: tensor<32xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]> : vector<32xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<32xi32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG1]] : index to vector<32xindex>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<32xindex>
+// CHECK: %[[CAST:.*]] = arith.index_cast %[[ADDI]] : vector<32xindex> to vector<32xi32>
+// CHECK: vector.transfer_write %[[CAST]], %[[EMPTY]][%[[C0]]] {in_bounds = [true]} : vector<32xi32>, tensor<32xi32>
+
+transform.sequence failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+    %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+    %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+}
+
+// -----
+
// CHECK-LABEL: func @test_vectorize_fill
func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
// CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
More information about the Mlir-commits
mailing list