[Mlir-commits] [mlir] [MLIR][Linalg] Fix insert_slice fusion with rank reduction (PR #130961)
Thomas Preud'homme
llvmlistbot at llvm.org
Wed May 21 01:21:49 PDT 2025
RoboTux (https://github.com/RoboTux) updated https://github.com/llvm/llvm-project/pull/130961
>From a370cd2d44b2715470c49dfb8b013d12dcff9826 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Wed, 12 Mar 2025 13:22:14 +0000
Subject: [PATCH 1/8] [MLIR][Linalg] Fix insert_slice fusion with rank
reduction
Insert_slice fusion with a linalg producer does not account for
possible rank-reduction in the insert_slice return type. When that
happens, a tensor.cast gets generated due to the type mismatch, which is
invalid for tensors of different rank. This later trips other passes.
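For illustration, a hedged sketch of the situation using the shapes from the
test added below (%prod, %i and %fused_prod are placeholder names, not output
of the pass). Fusing the producer of a rank-reducing slice such as

  %slice = tensor.extract_slice %prod[0, 0, %i] [1, 6, 2] [1, 1, 1]
      : tensor<1x6x6xf32> to tensor<6x2xf32>

clones the producer inside the loop with a result of type tensor<1x6x2xf32>.
Bridging that result to the consumer's tensor<6x2xf32> operand with a
tensor.cast is invalid because tensor.cast cannot change rank, so this patch
emits a tensor.collapse_shape instead:

  %collapsed = tensor.collapse_shape %fused_prod [[0, 1], [2]]
      : tensor<1x6x2xf32> into tensor<6x2xf32>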
---
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 36 ++++++++++-
.../Dialect/Linalg/tile-and-fuse-tensors.mlir | 63 +++++++++++++++++++
2 files changed, 97 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 223d728b0b27d..81b204df5a0aa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
@@ -26,6 +27,7 @@
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -235,6 +237,31 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
}
+/// Create tensor.collapse_shape to drop dimensions in `dropDims` in tensor
+/// `from`.
+tensor::CollapseShapeOp collapseTo(OpBuilder &b, Location loc, Value from,
+ const llvm::SmallBitVector &dropDims) {
+ auto fromType = cast<ShapedType>(from.getType());
+ assert(fromType.getRank() == dropDims.size());
+ SmallVector<ReassociationIndices, 2> reassocIdxsVec;
+ ReassociationIndices reassocIdxs;
+
+ bool foundKeptDim = false;
+ for (int dim = 0; dim < fromType.getRank(); dim++) {
+ if (!dropDims.test(dim)) {
+ if (foundKeptDim) {
+ reassocIdxsVec.push_back(reassocIdxs);
+ reassocIdxs.clear();
+ }
+ foundKeptDim = true;
+ }
+ reassocIdxs.push_back(dim);
+ }
+ if (!reassocIdxs.empty())
+ reassocIdxsVec.push_back(reassocIdxs);
+ return b.create<tensor::CollapseShapeOp>(loc, from, reassocIdxsVec);
+}
+
FailureOr<FusionInfo>
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
OpOperand &consumerOpOperand) {
@@ -255,6 +282,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
<< "\nNot fusable, not an extract_slice op: " << inputTensor);
return failure();
}
+ llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
// If producer is already in the same block as consumer, we are done.
if (consumerOpOperand.get().getParentBlock() ==
@@ -272,12 +300,16 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
consumerOpOperand);
// Replace use.
+ Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
+ Type consumerType = consumerOpOperand.get().getType();
+ // Rank-reduction occured as part of the extract_slice.
+ if (cast<ShapedType>(consumerType).getRank() !=
+ cast<ShapedType>(def.getType()).getRank())
+ def = collapseTo(b, fusedProducer.getLoc(), def, droppedDims);
// Canonicalizations are not guaranteed to have happened before constructing
// `fusedProducer`. In the tensor case this can result in temporary type
// mismatches. Insert a `tensor.cast` op to propagate the transformation
// invariant that types are compatible.
- Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
- Type consumerType = consumerOpOperand.get().getType();
if (consumerType != def.getType())
def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
consumerOpOperand.set(def);
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 0f27a92c119cf..b4fbdfacde899 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -318,3 +318,66 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
}
return %for0 : tensor<64x128xf32>
}
+
+// -----
+
+func.func @rank_reduced_extract_slice(%arg0: tensor<6x6x1x1x1x1xf32>, %arg1: tensor<6x6x1x1xf32>, %arg2: tensor<4x6xf32>) -> tensor<4x6xf32> {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c6 = arith.constant 6 : index
+ %cst = arith.constant 0.0 : f32
+ %init1 = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+ %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d6, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<6x6x1x1x1x1xf32>, tensor<6x6x1x1xf32>) outs(%fill1 : tensor<6x6x1x1x1x1xf32>) {
+ ^bb0(%in: f32, %in_1: f32, %out: f32):
+ %10 = arith.mulf %in, %in_1 : f32
+ %11 = arith.addf %out, %10 : f32
+ linalg.yield %11 : f32
+ } -> tensor<6x6x1x1x1x1xf32>
+ %init2 = tensor.empty() : tensor<4x6xf32>
+ %1 = scf.for %arg4 = %c0 to %c6 step %c2 iter_args(%arg3 = %init2) -> (tensor<4x6xf32>) {
+ %2 = tensor.extract_slice %0[0, %arg4, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
+ %init3 = tensor.empty() : tensor<4x2xf32>
+ %fill3 = linalg.fill ins(%cst : f32) outs(%init3 : tensor<4x2xf32>) -> tensor<4x2xf32>
+ %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%fill3 : tensor<4x2xf32>) {
+ ^bb0(%in: f32, %in_1: f32, %out: f32):
+ %20 = arith.mulf %in, %in_1 : f32
+ %21 = arith.addf %out, %20 : f32
+ linalg.yield %21 : f32
+ } -> tensor<4x2xf32>
+ %4 = tensor.insert_slice %3 into %arg3[0, %arg4] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+ scf.yield %4 : tensor<4x6xf32>
+ }
+ return %1 : tensor<4x6xf32>
+}
+
+// CHECK: func @rank_reduced_extract_slice(
+// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<6x6x1x1x1x1xf32>
+// CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<6x6x1x1xf32>
+// CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
+
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+// CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+// CHECK: %[[FILL_PROD:.*]] = linalg.fill ins({{%.*}} : f32)
+// CHECK-SAME: outs(%[[EMPTY_PROD]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+// CHECK: %[[EMPTY_FOR:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<4x2xf32>
+// CHECK: %[[FILL_CONS:.*]] = linalg.fill ins({{%.*}} : f32)
+// CHECK-SAME: outs(%[[EMPTY_CONS]] : tensor<4x2xf32>) -> tensor<4x2xf32>
+// CHECK: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[EMPTY_FOR]])
+// CHECK-DAG: %[[ARG0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[I]], 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
+// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[I]], 0, 0] [6, 2, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> to tensor<6x2x1x1xf32>
+// CHECK-DAG: %[[FILL_PROD_SLICE:.*]] = tensor.extract_slice %[[FILL_PROD]][0, %[[I]], 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
+
+// CHECK: %[[MMUL_PROD:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[ARG1_SLICE]] : tensor<6x2x1x1x1x1xf32>, tensor<6x2x1x1xf32>)
+// CHECK-SAME: outs(%[[FILL_PROD_SLICE]] : tensor<6x2x1x1x1x1xf32>)
+// CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0\], \[1, 2, 3, 4, 5\]\]}} : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>
+// CHECK: %[[MMUL_CONS:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
+// CHECK-SAME: outs(%[[FILL_CONS]] : tensor<4x2xf32>)
+// CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+// CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
+// CHECK: return %[[FOR]] : tensor<4x6xf32>
>From ce067327de33ee88397573ead73291764f15c627 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Tue, 25 Mar 2025 22:49:26 +0000
Subject: [PATCH 2/8] Add more comments and simplify test
---
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 12 ++-
.../Dialect/Linalg/tile-and-fuse-tensors.mlir | 88 ++++++++-----------
2 files changed, 45 insertions(+), 55 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 81b204df5a0aa..d18d6f7ff8dd8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -239,14 +239,20 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
/// Create tensor.collapse_shape to drop dimensions in `dropDims` in tensor
/// `from`.
-tensor::CollapseShapeOp collapseTo(OpBuilder &b, Location loc, Value from,
- const llvm::SmallBitVector &dropDims) {
+static tensor::CollapseShapeOp collapseTo(OpBuilder &b, Location loc, Value from,
+ const llvm::SmallBitVector &dropDims) {
auto fromType = cast<ShapedType>(from.getType());
- assert(fromType.getRank() == dropDims.size());
+ assert(fromType.getRank() == dropDims.size() && "dropDims dimension does not match from tensor rank");
+ // Computed reassociation map for the corresponding tensor.collapse_shape.
SmallVector<ReassociationIndices, 2> reassocIdxsVec;
+ // Current reassociation indices to add dropped dimension to.
ReassociationIndices reassocIdxs;
bool foundKeptDim = false;
+ // Dropped dimensions might be at the beginning or end of the shape so
+ // combine all contiguous dimensions before and after a given non dropped
+ // dimension in reassocIdxs until another non dropped dimension is found.
+ // When that happens, add the reassociation indices to the map.
for (int dim = 0; dim < fromType.getRank(); dim++) {
if (!dropDims.test(dim)) {
if (foundKeptDim) {
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index b4fbdfacde899..46b70a9c0edba 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -321,63 +321,47 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
// -----
-func.func @rank_reduced_extract_slice(%arg0: tensor<6x6x1x1x1x1xf32>, %arg1: tensor<6x6x1x1xf32>, %arg2: tensor<4x6xf32>) -> tensor<4x6xf32> {
- %c0 = arith.constant 0 : index
- %c2 = arith.constant 2 : index
- %c6 = arith.constant 6 : index
+func.func @rank_reduced_extract_slice(%cond : i1) -> tensor<6x2xf32> {
%cst = arith.constant 0.0 : f32
- %init1 = tensor.empty() : tensor<6x6x1x1x1x1xf32>
- %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
- %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d6, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<6x6x1x1x1x1xf32>, tensor<6x6x1x1xf32>) outs(%fill1 : tensor<6x6x1x1x1x1xf32>) {
- ^bb0(%in: f32, %in_1: f32, %out: f32):
- %10 = arith.mulf %in, %in_1 : f32
- %11 = arith.addf %out, %10 : f32
- linalg.yield %11 : f32
+ %cst1 = arith.constant 1.0 : f32
+
+ %empty1 = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+ %init1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} outs(%empty1 : tensor<6x6x1x1x1x1xf32>) {
+ ^bb0(%out: f32):
+ linalg.yield %cst : f32
} -> tensor<6x6x1x1x1x1xf32>
- %init2 = tensor.empty() : tensor<4x6xf32>
- %1 = scf.for %arg4 = %c0 to %c6 step %c2 iter_args(%arg3 = %init2) -> (tensor<4x6xf32>) {
- %2 = tensor.extract_slice %0[0, %arg4, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
- %init3 = tensor.empty() : tensor<4x2xf32>
- %fill3 = linalg.fill ins(%cst : f32) outs(%init3 : tensor<4x2xf32>) -> tensor<4x2xf32>
- %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%fill3 : tensor<4x2xf32>) {
- ^bb0(%in: f32, %in_1: f32, %out: f32):
- %20 = arith.mulf %in, %in_1 : f32
- %21 = arith.addf %out, %20 : f32
- linalg.yield %21 : f32
- } -> tensor<4x2xf32>
- %4 = tensor.insert_slice %3 into %arg3[0, %arg4] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
- scf.yield %4 : tensor<4x6xf32>
+
+ %if = scf.if %cond -> tensor<6x2xf32> {
+ %extract0 = tensor.extract_slice %init1[0, 0, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
+
+ %init2 = tensor.empty() : tensor<6x2xf32>
+ %add1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extract0 : tensor<6x2xf32>) outs(%init2 : tensor<6x2xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %add = arith.addf %in, %cst1 : f32
+ linalg.yield %add : f32
+ } -> tensor<6x2xf32>
+ scf.yield %add1 : tensor<6x2xf32>
+ } else {
+ %extract2 = tensor.extract_slice %init1[0, 2, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
+ scf.yield %extract2 : tensor<6x2xf32>
}
- return %1 : tensor<4x6xf32>
+
+ return %if : tensor<6x2xf32>
}
// CHECK: func @rank_reduced_extract_slice(
-// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<6x6x1x1x1x1xf32>
-// CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<6x6x1x1xf32>
-// CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
+// CHECK-SAME: %[[COND:[0-9a-z]*]]: i1
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
// CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
-// CHECK: %[[FILL_PROD:.*]] = linalg.fill ins({{%.*}} : f32)
-// CHECK-SAME: outs(%[[EMPTY_PROD]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
-// CHECK: %[[EMPTY_FOR:.*]] = tensor.empty() : tensor<4x6xf32>
-// CHECK: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<4x2xf32>
-// CHECK: %[[FILL_CONS:.*]] = linalg.fill ins({{%.*}} : f32)
-// CHECK-SAME: outs(%[[EMPTY_CONS]] : tensor<4x2xf32>) -> tensor<4x2xf32>
-// CHECK: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[EMPTY_FOR]])
-// CHECK-DAG: %[[ARG0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[I]], 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
-// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[I]], 0, 0] [6, 2, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> to tensor<6x2x1x1xf32>
-// CHECK-DAG: %[[FILL_PROD_SLICE:.*]] = tensor.extract_slice %[[FILL_PROD]][0, %[[I]], 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
-
-// CHECK: %[[MMUL_PROD:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[ARG1_SLICE]] : tensor<6x2x1x1x1x1xf32>, tensor<6x2x1x1xf32>)
-// CHECK-SAME: outs(%[[FILL_PROD_SLICE]] : tensor<6x2x1x1x1x1xf32>)
-// CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0\], \[1, 2, 3, 4, 5\]\]}} : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>
-// CHECK: %[[MMUL_CONS:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
-// CHECK-SAME: outs(%[[FILL_CONS]] : tensor<4x2xf32>)
-// CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
-// CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
-// CHECK: return %[[FOR]] : tensor<4x6xf32>
+// CHECK: %[[FILL_PROD:.*]] = linalg.generic
+// CHECK-SAME: outs(%[[EMPTY_PROD]] : tensor<6x6x1x1x1x1xf32>)
+
+// CHECK: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<6x2xf32>
+// CHECK: %[[EXTRACT_SLICE_CONS:.*]] = tensor.extract_slice %[[EMPTY_PROD]][0, 0, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
+
+// CHECK: %[[FILL_CONS:.*]] = linalg.generic
+// CHECK-SAME: outs(%[[EXTRACT_SLICE_CONS]] : tensor<6x2x1x1x1x1xf32>)
+// CHECK: %[[CONS_COLLAPSE:.*]] = tensor.collapse_shape %[[FILL_CONS]] {{\[\[0\], \[1, 2, 3, 4, 5\]\]}} : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>
+// CHECK: %[[ADD1_CONS:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[CONS_COLLAPSE]] : tensor<6x2xf32>)
+// CHECK-SAME: outs(%[[EMPTY_CONS]] : tensor<6x2xf32>)
>From b51029d13fb67e2860fdf3832f6b66ea1b30544f Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Wed, 26 Mar 2025 00:01:55 +0000
Subject: [PATCH 3/8] Fix clang-format
---
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index d18d6f7ff8dd8..bcb21263ee68f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -239,10 +239,12 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
/// Create tensor.collapse_shape to drop dimensions in `dropDims` in tensor
/// `from`.
-static tensor::CollapseShapeOp collapseTo(OpBuilder &b, Location loc, Value from,
- const llvm::SmallBitVector &dropDims) {
+static tensor::CollapseShapeOp
+collapseTo(OpBuilder &b, Location loc, Value from,
+ const llvm::SmallBitVector &dropDims) {
auto fromType = cast<ShapedType>(from.getType());
- assert(fromType.getRank() == dropDims.size() && "dropDims dimension does not match from tensor rank");
+ assert(fromType.getRank() == dropDims.size() &&
+ "dropDims dimension does not match from tensor rank");
// Computed reassociation map for the corresponding tensor.collapse_shape.
SmallVector<ReassociationIndices, 2> reassocIdxsVec;
// Current reassociation indices to add dropped dimension to.
>From b19b6490e71181c7e956223b3e4f33330fd0fde7 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Wed, 7 May 2025 10:56:37 +0100
Subject: [PATCH 4/8] Address comments
- rename collapseTo to better reflect its usage
- assert that it only collapses unit dimensions
- rename the ReassociationIndices-using variables to reassocGroup and
  reassocMaps, the same terminology used in the tensor.collapse_shape
  documentation (a short reminder of this terminology follows below)
- use a more representative test with comments to better explain what the
  patch does
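As a reminder of the terminology (a minimal sketch, not quoted from the
tensor.collapse_shape documentation; %c and %t are placeholder names): each
inner index list of a collapse_shape is one reassociation group, and the list
of all groups is the reassociation map. For instance, in

  %c = tensor.collapse_shape %t [[0], [1, 2, 3, 4, 5]]
      : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>

the groups are [0] and [1, 2, 3, 4, 5]; reassocGroup holds the group
currently being built and reassocMaps collects the finished groups.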
---
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 34 +++---
.../Dialect/Linalg/tile-and-fuse-tensors.mlir | 107 +++++++++++-------
2 files changed, 87 insertions(+), 54 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 3655c43940b06..c52f347135a9e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -236,37 +236,39 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
}
-/// Create tensor.collapse_shape to drop dimensions in `dropDims` in tensor
+/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
/// `from`.
static tensor::CollapseShapeOp
-collapseTo(OpBuilder &b, Location loc, Value from,
- const llvm::SmallBitVector &dropDims) {
+dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
+ const llvm::SmallBitVector &dropDims) {
auto fromType = cast<ShapedType>(from.getType());
- assert(fromType.getRank() == dropDims.size() &&
+ assert(fromType.getRank() == static_cast<int64_t>(dropDims.size()) &&
"dropDims dimension does not match from tensor rank");
// Computed reassociation map for the corresponding tensor.collapse_shape.
- SmallVector<ReassociationIndices, 2> reassocIdxsVec;
- // Current reassociation indices to add dropped dimension to.
- ReassociationIndices reassocIdxs;
+ SmallVector<ReassociationIndices, 2> reassocMaps;
+ // Current reassociation group to add dropped dimension to.
+ ReassociationIndices reassocGroup;
bool foundKeptDim = false;
// Dropped dimensions might be at the beginning or end of the shape so
// combine all contiguous dimensions before and after a given non dropped
- // dimension in reassocIdxs until another non dropped dimension is found.
+ // dimension in reassocGroup until another non dropped dimension is found.
// When that happens, add the reassociation indices to the map.
for (int dim = 0; dim < fromType.getRank(); dim++) {
- if (!dropDims.test(dim)) {
+ if (dropDims.test(dim))
+ assert(fromType.getShape()[dim] == 1 && "Dropping non unit dimension");
+ else {
if (foundKeptDim) {
- reassocIdxsVec.push_back(reassocIdxs);
- reassocIdxs.clear();
+ reassocMaps.push_back(reassocGroup);
+ reassocGroup.clear();
}
foundKeptDim = true;
}
- reassocIdxs.push_back(dim);
+ reassocGroup.push_back(dim);
}
- if (!reassocIdxs.empty())
- reassocIdxsVec.push_back(reassocIdxs);
- return b.create<tensor::CollapseShapeOp>(loc, from, reassocIdxsVec);
+ if (!reassocGroup.empty())
+ reassocMaps.push_back(reassocGroup);
+ return b.create<tensor::CollapseShapeOp>(loc, from, reassocMaps);
}
FailureOr<FusionInfo>
@@ -312,7 +314,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
// Rank-reduction occured as part of the extract_slice.
if (cast<ShapedType>(consumerType).getRank() !=
cast<ShapedType>(def.getType()).getRank())
- def = collapseTo(b, fusedProducer.getLoc(), def, droppedDims);
+ def = dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
// Canonicalizations are not guaranteed to have happened before constructing
// `fusedProducer`. In the tensor case this can result in temporary type
// mismatches. Insert a `tensor.cast` op to propagate the transformation
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 46b70a9c0edba..693a2bb29f76e 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -321,47 +321,78 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
// -----
-func.func @rank_reduced_extract_slice(%cond : i1) -> tensor<6x2xf32> {
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @rank_reduced_extract_slice(%arg0: tensor<1x6x5xf32>, %arg1: tensor<1x5x6xf32>, %arg2: tensor<4x6xf32>) -> tensor<4x6xf32> {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c6 = arith.constant 6 : index
%cst = arith.constant 0.0 : f32
- %cst1 = arith.constant 1.0 : f32
-
- %empty1 = tensor.empty() : tensor<6x6x1x1x1x1xf32>
- %init1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} outs(%empty1 : tensor<6x6x1x1x1x1xf32>) {
- ^bb0(%out: f32):
- linalg.yield %cst : f32
- } -> tensor<6x6x1x1x1x1xf32>
-
- %if = scf.if %cond -> tensor<6x2xf32> {
- %extract0 = tensor.extract_slice %init1[0, 0, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
-
- %init2 = tensor.empty() : tensor<6x2xf32>
- %add1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extract0 : tensor<6x2xf32>) outs(%init2 : tensor<6x2xf32>) {
- ^bb0(%in: f32, %out: f32):
- %add = arith.addf %in, %cst1 : f32
- linalg.yield %add : f32
- } -> tensor<6x2xf32>
- scf.yield %add1 : tensor<6x2xf32>
- } else {
- %extract2 = tensor.extract_slice %init1[0, 2, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
- scf.yield %extract2 : tensor<6x2xf32>
+ %init1 = tensor.empty() : tensor<1x6x6xf32>
+ %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<1x6x6xf32>) -> tensor<1x6x6xf32>
+ %0 = linalg.generic
+ {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%fill1 : tensor<1x6x6xf32>) {
+ ^bb0(%in: f32, %in_1: f32, %out: f32):
+ %10 = arith.mulf %in, %in_1 : f32
+ %11 = arith.addf %out, %10 : f32
+ linalg.yield %11 : f32
+ } -> tensor<1x6x6xf32>
+ %init2 = tensor.empty() : tensor<4x6xf32>
+ %1 = scf.for %arg4 = %c0 to %c6 step %c2 iter_args(%arg3 = %init2) -> (tensor<4x6xf32>) {
+ %2 = tensor.extract_slice %0[0, 0, %arg4] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
+ %init3 = tensor.empty() : tensor<4x2xf32>
+ %fill3 = linalg.fill ins(%cst : f32) outs(%init3 : tensor<4x2xf32>) -> tensor<4x2xf32>
+ %3 = linalg.generic
+ {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%fill3 : tensor<4x2xf32>) {
+ ^bb0(%in: f32, %in_1: f32, %out: f32):
+ %20 = arith.mulf %in, %in_1 : f32
+ %21 = arith.addf %out, %20 : f32
+ linalg.yield %21 : f32
+ } -> tensor<4x2xf32>
+ %4 = tensor.insert_slice %3 into %arg3[0, %arg4] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+ scf.yield %4 : tensor<4x6xf32>
}
-
- return %if : tensor<6x2xf32>
+ return %1 : tensor<4x6xf32>
}
// CHECK: func @rank_reduced_extract_slice(
-// CHECK-SAME: %[[COND:[0-9a-z]*]]: i1
-
-// CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
-// CHECK: %[[FILL_PROD:.*]] = linalg.generic
-// CHECK-SAME: outs(%[[EMPTY_PROD]] : tensor<6x6x1x1x1x1xf32>)
+// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<1x6x5xf32>
+// CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<1x5x6xf32>
+// CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
-// CHECK: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<6x2xf32>
-// CHECK: %[[EXTRACT_SLICE_CONS:.*]] = tensor.extract_slice %[[EMPTY_PROD]][0, 0, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
-
-// CHECK: %[[FILL_CONS:.*]] = linalg.generic
-// CHECK-SAME: outs(%[[EXTRACT_SLICE_CONS]] : tensor<6x2x1x1x1x1xf32>)
-// CHECK: %[[CONS_COLLAPSE:.*]] = tensor.collapse_shape %[[FILL_CONS]] {{\[\[0\], \[1, 2, 3, 4, 5\]\]}} : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>
-// CHECK: %[[ADD1_CONS:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[CONS_COLLAPSE]] : tensor<6x2xf32>)
-// CHECK-SAME: outs(%[[EMPTY_CONS]] : tensor<6x2xf32>)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<1x6x6xf32>
+// CHECK-NEXT: %[[FILL_PROD:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY_PROD]] : tensor<1x6x6xf32>) -> tensor<1x6x6xf32>
+// CHECK-NEXT: %[[EMPTY_FOR:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<4x2xf32>
+// CHECK-NEXT: %[[FILL_CONS:.*]] = linalg.fill ins(%[[CST]] : f32)
+
+// For loop right after tensor alloc & fill, no linalg.generic.
+// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[EMPTY_FOR]])
+
+// Producer linalg.generic now inside the loop, with tiled args sliced before
+// it.
+// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
+// CHECK-DAG: %[[PROD_SLICE:.*]] = tensor.extract_slice %[[FILL_PROD]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
+// CHECK: %[[MMUL_PROD:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
+// CHECK-SAME: outs(%[[PROD_SLICE]] : tensor<1x6x2xf32>)
+//
+// Consumer uses a rank-reduced version of producer result so a collapse_shape
+// is generated.
+// CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0, 1\], \[2\]\]}} : tensor<1x6x2xf32> into tensor<6x2xf32>
+// CHECK: %[[MMUL_CONS:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
+// CHECK-SAME: outs(%[[FILL_CONS]] : tensor<4x2xf32>)
+// CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+// CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
+// CHECK: return %[[FOR]] : tensor<4x6xf32>
>From 49deedf97d448f7bb71d47da0b0a1bbe7b53db96 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Mon, 19 May 2025 23:05:52 +0100
Subject: [PATCH 5/8] Clean up code
dropGivenUnitDims():
- move assert out of loop
- rework algorithm to make grouping more explicit and avoid complex
  nested ifs (a worked example follows below)
- fix 'occured' typo
Test: remove all tensor.empty and linalg.fill
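A worked example of the new grouping (hypothetical shape, not taken from the
test): for a source of type tensor<1x6x2x1xf32> with dropDims = {0, 3},
keptDims = {1, 2}. Kept dimension 1 closes the group [0, 1]; kept dimension 2
is the last set bit, so it also absorbs the trailing unit dimension, giving
[2, 3]. The helper therefore builds

  %r = tensor.collapse_shape %src [[0, 1], [2, 3]]
      : tensor<1x6x2x1xf32> into tensor<6x2xf32>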
---
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 42 +++++++++----------
.../Dialect/Linalg/tile-and-fuse-tensors.mlir | 37 +++++++---------
2 files changed, 36 insertions(+), 43 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index c52f347135a9e..d69c85984aa4e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -242,32 +242,30 @@ static tensor::CollapseShapeOp
dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
const llvm::SmallBitVector &dropDims) {
auto fromType = cast<ShapedType>(from.getType());
- assert(fromType.getRank() == static_cast<int64_t>(dropDims.size()) &&
+ int64_t rank = fromType.getRank();
+ assert(rank == static_cast<int64_t>(dropDims.size()) &&
"dropDims dimension does not match from tensor rank");
+ assert(llvm::all_of(
+ dropDims.set_bits(),
+ [&](unsigned dim) { return fromType.getShape()[dim] == 1; }) &&
+ "Dropping non unit dimension");
// Computed reassociation map for the corresponding tensor.collapse_shape.
SmallVector<ReassociationIndices, 2> reassocMaps;
// Current reassociation group to add dropped dimension to.
- ReassociationIndices reassocGroup;
-
- bool foundKeptDim = false;
- // Dropped dimensions might be at the beginning or end of the shape so
- // combine all contiguous dimensions before and after a given non dropped
- // dimension in reassocGroup until another non dropped dimension is found.
- // When that happens, add the reassociation indices to the map.
- for (int dim = 0; dim < fromType.getRank(); dim++) {
- if (dropDims.test(dim))
- assert(fromType.getShape()[dim] == 1 && "Dropping non unit dimension");
- else {
- if (foundKeptDim) {
- reassocMaps.push_back(reassocGroup);
- reassocGroup.clear();
- }
- foundKeptDim = true;
- }
- reassocGroup.push_back(dim);
+
+ int64_t nextDimToGroup = 0;
+ llvm::SmallBitVector keptDims(dropDims);
+ keptDims.flip();
+ int64_t lastSetBit = keptDims.find_last();
+ for(int64_t setBit : keptDims.set_bits()) {
+ // Group consecutive dropped dimension with the next non-dropped dimension.
+ // If this is the last set dimension, also group all subsequent dropped
+ // dimension, if any.
+ int64_t upTo = setBit == lastSetBit ? rank - 1 : setBit;
+ auto seq = llvm::seq_inclusive(nextDimToGroup, upTo);
+ reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));
+ nextDimToGroup = setBit + 1;
}
- if (!reassocGroup.empty())
- reassocMaps.push_back(reassocGroup);
return b.create<tensor::CollapseShapeOp>(loc, from, reassocMaps);
}
@@ -311,7 +309,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
// Replace use.
Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
Type consumerType = consumerOpOperand.get().getType();
- // Rank-reduction occured as part of the extract_slice.
+ // Rank-reduction occurred as part of the extract_slice.
if (cast<ShapedType>(consumerType).getRank() !=
cast<ShapedType>(def.getType()).getRank())
def = dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 693a2bb29f76e..9340e70b4d507 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -327,35 +327,32 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
#map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @rank_reduced_extract_slice(%arg0: tensor<1x6x5xf32>, %arg1: tensor<1x5x6xf32>, %arg2: tensor<4x6xf32>) -> tensor<4x6xf32> {
+func.func @rank_reduced_extract_slice(
+ %arg0: tensor<1x6x5xf32>, %arg1: tensor<1x5x6xf32>, %arg2: tensor<4x6xf32>,
+ %arg3: tensor<1x6x6xf32>, %arg4: tensor<4x6xf32>, %arg5: tensor<4x2xf32>
+) -> tensor<4x6xf32> {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c6 = arith.constant 6 : index
- %cst = arith.constant 0.0 : f32
- %init1 = tensor.empty() : tensor<1x6x6xf32>
- %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<1x6x6xf32>) -> tensor<1x6x6xf32>
%0 = linalg.generic
{indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
- ins(%arg0, %arg1 : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%fill1 : tensor<1x6x6xf32>) {
+ ins(%arg0, %arg1 : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%arg3 : tensor<1x6x6xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%10 = arith.mulf %in, %in_1 : f32
%11 = arith.addf %out, %10 : f32
linalg.yield %11 : f32
} -> tensor<1x6x6xf32>
- %init2 = tensor.empty() : tensor<4x6xf32>
- %1 = scf.for %arg4 = %c0 to %c6 step %c2 iter_args(%arg3 = %init2) -> (tensor<4x6xf32>) {
- %2 = tensor.extract_slice %0[0, 0, %arg4] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
- %init3 = tensor.empty() : tensor<4x2xf32>
- %fill3 = linalg.fill ins(%cst : f32) outs(%init3 : tensor<4x2xf32>) -> tensor<4x2xf32>
+ %1 = scf.for %arg7 = %c0 to %c6 step %c2 iter_args(%arg6 = %arg4) -> (tensor<4x6xf32>) {
+ %2 = tensor.extract_slice %0[0, 0, %arg7] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
%3 = linalg.generic
{indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]}
- ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%fill3 : tensor<4x2xf32>) {
+ ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%arg5 : tensor<4x2xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%20 = arith.mulf %in, %in_1 : f32
%21 = arith.addf %out, %20 : f32
linalg.yield %21 : f32
} -> tensor<4x2xf32>
- %4 = tensor.insert_slice %3 into %arg3[0, %arg4] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+ %4 = tensor.insert_slice %3 into %arg6[0, %arg7] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
scf.yield %4 : tensor<4x6xf32>
}
return %1 : tensor<4x6xf32>
@@ -365,24 +362,22 @@ func.func @rank_reduced_extract_slice(%arg0: tensor<1x6x5xf32>, %arg1: tensor<1x
// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<1x6x5xf32>
// CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<1x5x6xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
+// CHECK-SAME: %[[ARG3:[0-9a-z]*]]: tensor<1x6x6xf32>
+// CHECK-SAME: %[[ARG4:[0-9a-z]*]]: tensor<4x6xf32>
+// CHECK-SAME: %[[ARG5:[0-9a-z]*]]: tensor<4x2xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
-// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<1x6x6xf32>
-// CHECK-NEXT: %[[FILL_PROD:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY_PROD]] : tensor<1x6x6xf32>) -> tensor<1x6x6xf32>
-// CHECK-NEXT: %[[EMPTY_FOR:.*]] = tensor.empty() : tensor<4x6xf32>
-// CHECK-NEXT: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<4x2xf32>
-// CHECK-NEXT: %[[FILL_CONS:.*]] = linalg.fill ins(%[[CST]] : f32)
// For loop right after tensor alloc & fill, no linalg.generic.
-// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[EMPTY_FOR]])
+// CHECK-NOT: linalg.generic
+// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[ARG4]])
// Producer linalg.generic now inside the loop, with tiled args sliced before
// it.
// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
-// CHECK-DAG: %[[PROD_SLICE:.*]] = tensor.extract_slice %[[FILL_PROD]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
+// CHECK-DAG: %[[PROD_SLICE:.*]] = tensor.extract_slice %[[ARG3]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
// CHECK: %[[MMUL_PROD:.*]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
// CHECK-SAME: outs(%[[PROD_SLICE]] : tensor<1x6x2xf32>)
@@ -392,7 +387,7 @@ func.func @rank_reduced_extract_slice(%arg0: tensor<1x6x5xf32>, %arg1: tensor<1x
// CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0, 1\], \[2\]\]}} : tensor<1x6x2xf32> into tensor<6x2xf32>
// CHECK: %[[MMUL_CONS:.*]] = linalg.generic
// CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
-// CHECK-SAME: outs(%[[FILL_CONS]] : tensor<4x2xf32>)
+// CHECK-SAME: outs(%[[ARG5]] : tensor<4x2xf32>)
// CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
// CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
// CHECK: return %[[FOR]] : tensor<4x6xf32>
>From cf20e80c175bbfd6990aa3d58f0d972e796f7510 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Tue, 20 May 2025 11:33:20 +0100
Subject: [PATCH 6/8] Fix codestyle
---
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index d69c85984aa4e..f983fb5e40fa7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -257,7 +257,7 @@ dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
llvm::SmallBitVector keptDims(dropDims);
keptDims.flip();
int64_t lastSetBit = keptDims.find_last();
- for(int64_t setBit : keptDims.set_bits()) {
+ for (int64_t setBit : keptDims.set_bits()) {
// Group consecutive dropped dimension with the next non-dropped dimension.
// If this is the last set dimension, also group all subsequent dropped
// dimension, if any.
>From 42d8959bc2594be9218c1e709320e048cde8cc84 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Tue, 20 May 2025 13:26:57 +0100
Subject: [PATCH 7/8] Move dropGivenUnitDims to Tensor Utils
---
.../include/mlir/Dialect/Tensor/Utils/Utils.h | 5 +++
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 36 ++-----------------
mlir/lib/Dialect/Tensor/Utils/Utils.cpp | 33 +++++++++++++++++
3 files changed, 40 insertions(+), 34 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 22ca8a99dd7db..6c2a55f67db87 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -43,6 +43,11 @@ FailureOr<RankedTensorType>
computeTransposedType(RankedTensorType rankedTensorType,
ArrayRef<int64_t> transposeVector);
+/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
+/// `from`.
+CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
+ const llvm::SmallBitVector &dropDims);
+
/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
/// source tensor or inserts the source tensor into a destination tensor with
/// the same shape.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index f983fb5e40fa7..e3673e2a385c0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -236,39 +236,6 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
}
-/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
-/// `from`.
-static tensor::CollapseShapeOp
-dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
- const llvm::SmallBitVector &dropDims) {
- auto fromType = cast<ShapedType>(from.getType());
- int64_t rank = fromType.getRank();
- assert(rank == static_cast<int64_t>(dropDims.size()) &&
- "dropDims dimension does not match from tensor rank");
- assert(llvm::all_of(
- dropDims.set_bits(),
- [&](unsigned dim) { return fromType.getShape()[dim] == 1; }) &&
- "Dropping non unit dimension");
- // Computed reassociation map for the corresponding tensor.collapse_shape.
- SmallVector<ReassociationIndices, 2> reassocMaps;
- // Current reassociation group to add dropped dimension to.
-
- int64_t nextDimToGroup = 0;
- llvm::SmallBitVector keptDims(dropDims);
- keptDims.flip();
- int64_t lastSetBit = keptDims.find_last();
- for (int64_t setBit : keptDims.set_bits()) {
- // Group consecutive dropped dimension with the next non-dropped dimension.
- // If this is the last set dimension, also group all subsequent dropped
- // dimension, if any.
- int64_t upTo = setBit == lastSetBit ? rank - 1 : setBit;
- auto seq = llvm::seq_inclusive(nextDimToGroup, upTo);
- reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));
- nextDimToGroup = setBit + 1;
- }
- return b.create<tensor::CollapseShapeOp>(loc, from, reassocMaps);
-}
-
FailureOr<FusionInfo>
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
OpOperand &consumerOpOperand) {
@@ -312,7 +279,8 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
// Rank-reduction occurred as part of the extract_slice.
if (cast<ShapedType>(consumerType).getRank() !=
cast<ShapedType>(def.getType()).getRank())
- def = dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
+ def =
+ tensor::dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
// Canonicalizations are not guaranteed to have happened before constructing
// `fusedProducer`. In the tensor case this can result in temporary type
// mismatches. Insert a `tensor.cast` op to propagate the transformation
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index c3d56759a896a..53a219dff48c5 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -94,6 +94,39 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
return transposedTensorType;
}
+/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
+/// `from`.
+CollapseShapeOp
+mlir::tensor::dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
+ const llvm::SmallBitVector &dropDims) {
+ auto fromType = cast<ShapedType>(from.getType());
+ int64_t rank = fromType.getRank();
+ assert(rank == static_cast<int64_t>(dropDims.size()) &&
+ "dropDims dimension does not match from tensor rank");
+ assert(llvm::all_of(
+ dropDims.set_bits(),
+ [&](unsigned dim) { return fromType.getShape()[dim] == 1; }) &&
+ "Dropping non unit dimension");
+ // Computed reassociation map for the corresponding tensor.collapse_shape.
+ SmallVector<ReassociationIndices, 2> reassocMaps;
+ // Current reassociation group to add dropped dimension to.
+
+ int64_t nextDimToGroup = 0;
+ llvm::SmallBitVector keptDims(dropDims);
+ keptDims.flip();
+ int64_t lastSetBit = keptDims.find_last();
+ for (int64_t setBit : keptDims.set_bits()) {
+ // Group consecutive dropped dimension with the next non-dropped dimension.
+ // If this is the last set dimension, also group all subsequent dropped
+ // dimension, if any.
+ int64_t upTo = setBit == lastSetBit ? rank - 1 : setBit;
+ auto seq = llvm::seq_inclusive(nextDimToGroup, upTo);
+ reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));
+ nextDimToGroup = setBit + 1;
+ }
+ return b.create<tensor::CollapseShapeOp>(loc, from, reassocMaps);
+}
+
bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
llvm::SmallBitVector droppedDims = op.getDroppedDims();
int64_t srcDim = 0;
>From 6fc320ddd5696da921939eae2bea403c37c8568b Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Tue, 20 May 2025 22:43:32 +0100
Subject: [PATCH 8/8] Address review comments
Utils:
- drop comments on implementation
- rename `from` to `src`
Fusion:
- restrict live range of droppedDims
- clarify comment for rank-reduction check
Test:
- Use more descriptive SSA and FileCheck variables
- Emphasize the rank-reducing extract_slice in the input IR as the key
aspect of the test.
---
.../include/mlir/Dialect/Tensor/Utils/Utils.h | 4 +-
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 8 +--
mlir/lib/Dialect/Tensor/Utils/Utils.cpp | 14 +++--
.../Dialect/Linalg/tile-and-fuse-tensors.mlir | 51 ++++++++++---------
4 files changed, 41 insertions(+), 36 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 6c2a55f67db87..1a4733df3f187 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -44,8 +44,8 @@ computeTransposedType(RankedTensorType rankedTensorType,
ArrayRef<int64_t> transposeVector);
/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
-/// `from`.
-CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
+/// `src`.
+CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value src,
const llvm::SmallBitVector &dropDims);
/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index e3673e2a385c0..4fc8a17554435 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -256,7 +256,6 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
<< "\nNot fusable, not an extract_slice op: " << inputTensor);
return failure();
}
- llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
// If producer is already in the same block as consumer, we are done.
if (consumerOpOperand.get().getParentBlock() ==
@@ -276,11 +275,14 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
// Replace use.
Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
Type consumerType = consumerOpOperand.get().getType();
- // Rank-reduction occurred as part of the extract_slice.
+ // Check if rank-reduction occurred as part of the extract_slice. If yes,
+ // collapse the dropped dimensions.
if (cast<ShapedType>(consumerType).getRank() !=
- cast<ShapedType>(def.getType()).getRank())
+ cast<ShapedType>(def.getType()).getRank()) {
+ llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
def =
tensor::dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
+ }
// Canonicalizations are not guaranteed to have happened before constructing
// `fusedProducer`. In the tensor case this can result in temporary type
// mismatches. Insert a `tensor.cast` op to propagate the transformation
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 53a219dff48c5..11ae0108594dd 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -94,18 +94,16 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
return transposedTensorType;
}
-/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
-/// `from`.
CollapseShapeOp
-mlir::tensor::dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
+mlir::tensor::dropGivenUnitDims(OpBuilder &b, Location loc, Value src,
const llvm::SmallBitVector &dropDims) {
- auto fromType = cast<ShapedType>(from.getType());
- int64_t rank = fromType.getRank();
+ auto srcType = cast<ShapedType>(src.getType());
+ int64_t rank = srcType.getRank();
assert(rank == static_cast<int64_t>(dropDims.size()) &&
- "dropDims dimension does not match from tensor rank");
+ "dropDims dimension does not match src tensor rank");
assert(llvm::all_of(
dropDims.set_bits(),
- [&](unsigned dim) { return fromType.getShape()[dim] == 1; }) &&
+ [&](unsigned dim) { return srcType.getShape()[dim] == 1; }) &&
"Dropping non unit dimension");
// Computed reassociation map for the corresponding tensor.collapse_shape.
SmallVector<ReassociationIndices, 2> reassocMaps;
@@ -124,7 +122,7 @@ mlir::tensor::dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));
nextDimToGroup = setBit + 1;
}
- return b.create<tensor::CollapseShapeOp>(loc, from, reassocMaps);
+ return b.create<tensor::CollapseShapeOp>(loc, src, reassocMaps);
}
bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 9340e70b4d507..fd755a208b2c9 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -328,43 +328,48 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @rank_reduced_extract_slice(
- %arg0: tensor<1x6x5xf32>, %arg1: tensor<1x5x6xf32>, %arg2: tensor<4x6xf32>,
- %arg3: tensor<1x6x6xf32>, %arg4: tensor<4x6xf32>, %arg5: tensor<4x2xf32>
+ %prod_in: tensor<1x6x5xf32>, %prod_weight: tensor<1x5x6xf32>,
+ %cons_in: tensor<4x6xf32>, %prod_init: tensor<1x6x6xf32>,
+ %for_iv_init: tensor<4x6xf32>, %cons_init: tensor<4x2xf32>
) -> tensor<4x6xf32> {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c6 = arith.constant 6 : index
- %0 = linalg.generic
+ %mmul_prod = linalg.generic
{indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
- ins(%arg0, %arg1 : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%arg3 : tensor<1x6x6xf32>) {
+ ins(%prod_in, %prod_weight : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%prod_init : tensor<1x6x6xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%10 = arith.mulf %in, %in_1 : f32
%11 = arith.addf %out, %10 : f32
linalg.yield %11 : f32
} -> tensor<1x6x6xf32>
- %1 = scf.for %arg7 = %c0 to %c6 step %c2 iter_args(%arg6 = %arg4) -> (tensor<4x6xf32>) {
- %2 = tensor.extract_slice %0[0, 0, %arg7] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
- %3 = linalg.generic
+ %for = scf.for %arg7 = %c0 to %c6 step %c2 iter_args(%arg6 = %for_iv_init) -> (tensor<4x6xf32>) {
+
+ // Extract slice with rank-reduced result type. When fused in the loop
+ // with sliced operands, the producer linalg must have its now sliced
+ // result be rank-reduced as well to match consumer's use type.
+ %prod_slice = tensor.extract_slice %mmul_prod[0, 0, %arg7] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
+ %mmul_cons = linalg.generic
{indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]}
- ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%arg5 : tensor<4x2xf32>) {
+ ins(%cons_in, %prod_slice : tensor<4x6xf32>, tensor<6x2xf32>) outs(%cons_init : tensor<4x2xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%20 = arith.mulf %in, %in_1 : f32
%21 = arith.addf %out, %20 : f32
linalg.yield %21 : f32
} -> tensor<4x2xf32>
- %4 = tensor.insert_slice %3 into %arg6[0, %arg7] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+ %4 = tensor.insert_slice %mmul_cons into %arg6[0, %arg7] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
scf.yield %4 : tensor<4x6xf32>
}
- return %1 : tensor<4x6xf32>
+ return %for : tensor<4x6xf32>
}
// CHECK: func @rank_reduced_extract_slice(
-// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<1x6x5xf32>
-// CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<1x5x6xf32>
-// CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
-// CHECK-SAME: %[[ARG3:[0-9a-z]*]]: tensor<1x6x6xf32>
-// CHECK-SAME: %[[ARG4:[0-9a-z]*]]: tensor<4x6xf32>
-// CHECK-SAME: %[[ARG5:[0-9a-z]*]]: tensor<4x2xf32>
+// CHECK-SAME: %[[PROD_IN:[0-9a-z]*]]: tensor<1x6x5xf32>
+// CHECK-SAME: %[[PROD_WEIGHT:[0-9a-z]*]]: tensor<1x5x6xf32>
+// CHECK-SAME: %[[CONS_IN:[0-9a-z]*]]: tensor<4x6xf32>
+// CHECK-SAME: %[[PROD_INIT:[0-9a-z]*]]: tensor<1x6x6xf32>
+// CHECK-SAME: %[[FOR_IV_INIT:[0-9a-z]*]]: tensor<4x6xf32>
+// CHECK-SAME: %[[CONS_INIT:[0-9a-z]*]]: tensor<4x2xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
@@ -372,22 +377,22 @@ func.func @rank_reduced_extract_slice(
// For loop right after tensor alloc & fill, no linalg.generic.
// CHECK-NOT: linalg.generic
-// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[ARG4]])
+// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[FOR_IV_INIT]])
// Producer linalg.generic now inside the loop, with tiled args sliced before
// it.
-// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
-// CHECK-DAG: %[[PROD_SLICE:.*]] = tensor.extract_slice %[[ARG3]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
+// CHECK-DAG: %[[PROD_WEIGHT_SLICE:.*]] = tensor.extract_slice %[[PROD_WEIGHT]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
+// CHECK-DAG: %[[PROD_INIT_SLICE:.*]] = tensor.extract_slice %[[PROD_INIT]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
// CHECK: %[[MMUL_PROD:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]], %[[ARG1_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
-// CHECK-SAME: outs(%[[PROD_SLICE]] : tensor<1x6x2xf32>)
+// CHECK-SAME: ins(%[[PROD_IN]], %[[PROD_WEIGHT_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
+// CHECK-SAME: outs(%[[PROD_INIT_SLICE]] : tensor<1x6x2xf32>)
//
// Consumer uses a rank-reduced version of producer result so a collapse_shape
// is generated.
// CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0, 1\], \[2\]\]}} : tensor<1x6x2xf32> into tensor<6x2xf32>
// CHECK: %[[MMUL_CONS:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
-// CHECK-SAME: outs(%[[ARG5]] : tensor<4x2xf32>)
+// CHECK-SAME: ins(%[[CONS_IN]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
+// CHECK-SAME: outs(%[[CONS_INIT]] : tensor<4x2xf32>)
// CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
// CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
// CHECK: return %[[FOR]] : tensor<4x6xf32>