[Mlir-commits] [mlir] 2e12bad - [MLIR][Linalg] Fix insert_slice fusion with rank reduction (#130961)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 23 01:32:42 PDT 2025


Author: Thomas Preud'homme
Date: 2025-05-23T09:32:39+01:00
New Revision: 2e12badc941282f041c91f5372af29e35c59e4a2

URL: https://github.com/llvm/llvm-project/commit/2e12badc941282f041c91f5372af29e35c59e4a2
DIFF: https://github.com/llvm/llvm-project/commit/2e12badc941282f041c91f5372af29e35c59e4a2.diff

LOG: [MLIR][Linalg] Fix insert_slice fusion with rank reduction (#130961)

Insert_slice fusion with a linalg producer does not account for a
possible rank reduction in the insert_slice return type. When that
happens, a tensor.cast gets generated due to the type mismatch, but such
a cast between tensors of different ranks is invalid and later trips up
other passes.
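
A minimal sketch of the failure mode (types invented for illustration):
tensor.cast requires its operand and result to have the same rank, so the
cast that fusion used to emit for a rank-reduced slice type is invalid:

    // Invalid: tensor.cast cannot change the rank of a ranked tensor.
    %bad = tensor.cast %fused : tensor<1x6x2xf32> to tensor<6x2xf32>

The fix collapses the dropped unit dimensions with tensor.collapse_shape
instead, as exercised by the test added below.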

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Tensor/Utils/Utils.cpp
    mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 22ca8a99dd7db..1a4733df3f187 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -43,6 +43,11 @@ FailureOr<RankedTensorType>
 computeTransposedType(RankedTensorType rankedTensorType,
                       ArrayRef<int64_t> transposeVector);
 
+/// Create a tensor.collapse_shape op that drops the unit dimensions listed in
+/// `dropDims` from tensor `src`.
+CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value src,
+                                  const llvm::SmallBitVector &dropDims);
+
 /// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
 /// source tensor or inserts the source tensor into a destination tensor with
 /// the same shape.
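
For example (a sketch of what the new utility builds, with an invented
shape), dropping unit dimensions 0 and 2 of a tensor<1x6x1x2xf32> yields:

    %0 = tensor.collapse_shape %src [[0, 1], [2, 3]]
        : tensor<1x6x1x2xf32> into tensor<6x2xf32>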

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index fcfb499bb1332..4fc8a17554435 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Dominance.h"
@@ -26,6 +27,7 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
@@ -271,12 +273,20 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
            consumerOpOperand);
 
   // Replace use.
+  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
+  Type consumerType = consumerOpOperand.get().getType();
+  // Check whether rank reduction occurred as part of the extract_slice. If
+  // so, collapse the dropped dimensions.
+  if (cast<ShapedType>(consumerType).getRank() !=
+      cast<ShapedType>(def.getType()).getRank()) {
+    llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
+    def =
+        tensor::dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
+  }
   // Canonicalizations are not guaranteed to have happened before constructing
   // `fusedProducer`. In the tensor case this can result in temporary type
   // mismatches. Insert a `tensor.cast` op to propagate the transformation
   // invariant that types are compatible.
-  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
-  Type consumerType = consumerOpOperand.get().getType();
   if (consumerType != def.getType())
     def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
   consumerOpOperand.set(def);
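
With this change, when the consumer's operand type is rank-reduced relative
to the fused producer's result, the dropped unit dimensions are collapsed
first, e.g. (a sketch mirroring the test below):

    %collapsed = tensor.collapse_shape %fused [[0, 1], [2]]
        : tensor<1x6x2xf32> into tensor<6x2xf32>

and a tensor.cast is only inserted afterwards if a same-rank type mismatch
remains.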

diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index c3d56759a896a..11ae0108594dd 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -94,6 +94,37 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
   return transposedTensorType;
 }
 
+CollapseShapeOp
+mlir::tensor::dropGivenUnitDims(OpBuilder &b, Location loc, Value src,
+                                const llvm::SmallBitVector &dropDims) {
+  auto srcType = cast<ShapedType>(src.getType());
+  int64_t rank = srcType.getRank();
+  assert(rank == static_cast<int64_t>(dropDims.size()) &&
+         "dropDims dimension does not match src tensor rank");
+  assert(llvm::all_of(
+             dropDims.set_bits(),
+             [&](unsigned dim) { return srcType.getShape()[dim] == 1; }) &&
+         "Dropping non unit dimension");
+  // Computed reassociation map for the corresponding tensor.collapse_shape.
+  SmallVector<ReassociationIndices, 2> reassocMaps;
+  // First dimension not yet assigned to a reassociation group.
+  int64_t nextDimToGroup = 0;
+  llvm::SmallBitVector keptDims(dropDims);
+  keptDims.flip();
+  int64_t lastSetBit = keptDims.find_last();
+  for (int64_t setBit : keptDims.set_bits()) {
+    // Group consecutive dropped dimensions with the next non-dropped dimension.
+    // If this is the last kept dimension, also group any subsequent dropped
+    // dimensions.
+    int64_t upTo = setBit == lastSetBit ? rank - 1 : setBit;
+    auto seq = llvm::seq_inclusive(nextDimToGroup, upTo);
+    reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));
+    nextDimToGroup = setBit + 1;
+  }
+  return b.create<tensor::CollapseShapeOp>(loc, src, reassocMaps);
+}
+
 bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
   llvm::SmallBitVector droppedDims = op.getDroppedDims();
   int64_t srcDim = 0;
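
To see how the grouping loop behaves, consider an invented
tensor<1x4x1x8x1xf32> with dropDims = {0, 2, 4}. The kept dimensions are
{1, 3}: kept dimension 1 absorbs the dropped dimension 0 before it into the
group [0, 1], while the last kept dimension 3 absorbs both the dropped
dimension 2 before it and the trailing dropped dimension 4 into [2, 3, 4],
so the utility builds:

    %0 = tensor.collapse_shape %src [[0, 1], [2, 3, 4]]
        : tensor<1x4x1x8x1xf32> into tensor<4x8xf32>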

diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 0f27a92c119cf..fd755a208b2c9 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -318,3 +318,81 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
   }
   return %for0 : tensor<64x128xf32>
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @rank_reduced_extract_slice(
+    %prod_in: tensor<1x6x5xf32>, %prod_weight: tensor<1x5x6xf32>,
+    %cons_in: tensor<4x6xf32>, %prod_init: tensor<1x6x6xf32>,
+    %for_iv_init: tensor<4x6xf32>, %cons_init: tensor<4x2xf32>
+) -> tensor<4x6xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c6 = arith.constant 6 : index
+  %mmul_prod = linalg.generic
+    {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+    ins(%prod_in, %prod_weight : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%prod_init : tensor<1x6x6xf32>) {
+  ^bb0(%in: f32, %in_1: f32, %out: f32):
+    %10 = arith.mulf %in, %in_1 : f32
+    %11 = arith.addf %out, %10 : f32
+    linalg.yield %11 : f32
+  } -> tensor<1x6x6xf32>
+  %for = scf.for %arg7 = %c0 to %c6 step %c2 iter_args(%arg6 = %for_iv_init) -> (tensor<4x6xf32>) {
+
+    // Extract a slice with a rank-reduced result type. When the producer is
+    // fused into the loop with sliced operands, its now-sliced result must be
+    // rank-reduced as well to match the consumer's operand type.
+    %prod_slice = tensor.extract_slice %mmul_prod[0, 0, %arg7] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
+    %mmul_cons = linalg.generic
+     {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]}
+     ins(%cons_in, %prod_slice : tensor<4x6xf32>, tensor<6x2xf32>) outs(%cons_init : tensor<4x2xf32>) {
+    ^bb0(%in: f32, %in_1: f32, %out: f32):
+      %20 = arith.mulf %in, %in_1 : f32
+      %21 = arith.addf %out, %20 : f32
+      linalg.yield %21 : f32
+    } -> tensor<4x2xf32>
+    %4 = tensor.insert_slice %mmul_cons into %arg6[0, %arg7] [4, 2] [1, 1]  : tensor<4x2xf32> into tensor<4x6xf32>
+    scf.yield %4 : tensor<4x6xf32>
+  }
+  return %for : tensor<4x6xf32>
+}
+
+//       CHECK: func @rank_reduced_extract_slice(
+//  CHECK-SAME: %[[PROD_IN:[0-9a-z]*]]: tensor<1x6x5xf32>
+//  CHECK-SAME: %[[PROD_WEIGHT:[0-9a-z]*]]: tensor<1x5x6xf32>
+//  CHECK-SAME: %[[CONS_IN:[0-9a-z]*]]: tensor<4x6xf32>
+//  CHECK-SAME: %[[PROD_INIT:[0-9a-z]*]]: tensor<1x6x6xf32>
+//  CHECK-SAME: %[[FOR_IV_INIT:[0-9a-z]*]]: tensor<4x6xf32>
+//  CHECK-SAME: %[[CONS_INIT:[0-9a-z]*]]: tensor<4x2xf32>
+
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+
+//  The scf.for comes first; no linalg.generic remains before the loop.
+//   CHECK-NOT: linalg.generic
+//  CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[FOR_IV_INIT]])
+
+//  Producer linalg.generic now inside the loop, with tiled args sliced before
+//  it.
+//   CHECK-DAG:   %[[PROD_WEIGHT_SLICE:.*]] = tensor.extract_slice %[[PROD_WEIGHT]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1]  : tensor<1x5x6xf32> to tensor<1x5x2xf32>
+//   CHECK-DAG:   %[[PROD_INIT_SLICE:.*]] = tensor.extract_slice %[[PROD_INIT]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1]  : tensor<1x6x6xf32> to tensor<1x6x2xf32>
+//       CHECK:    %[[MMUL_PROD:.*]] = linalg.generic
+//  CHECK-SAME:        ins(%[[PROD_IN]], %[[PROD_WEIGHT_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
+//  CHECK-SAME:        outs(%[[PROD_INIT_SLICE]] : tensor<1x6x2xf32>)
+//
+//  The consumer uses a rank-reduced version of the producer result, so a
+//  collapse_shape is generated.
+//       CHECK:    %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0, 1\], \[2\]\]}} : tensor<1x6x2xf32> into tensor<6x2xf32>
+//       CHECK:    %[[MMUL_CONS:.*]] = linalg.generic
+//  CHECK-SAME:        ins(%[[CONS_IN]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
+//  CHECK-SAME:        outs(%[[CONS_INIT]] : tensor<4x2xf32>)
+//       CHECK:   %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+//       CHECK:   scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
+//       CHECK: return %[[FOR]] : tensor<4x6xf32>


        

