[Mlir-commits] [mlir] [MLIR][Linalg] Fix insert_slice fusion with rank reduction (PR #130961)

Thomas Preud'homme llvmlistbot at llvm.org
Tue Mar 25 17:01:34 PDT 2025


https://github.com/RoboTux updated https://github.com/llvm/llvm-project/pull/130961

>From a370cd2d44b2715470c49dfb8b013d12dcff9826 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Wed, 12 Mar 2025 13:22:14 +0000
Subject: [PATCH 1/2] [MLIR][Linalg] Fix insert_slice fusion with rank
 reduction

Fusion of a linalg producer through a tensor.extract_slice does not
account for possible rank reduction in the extract_slice result type.
When that happens, a tensor.cast is generated to reconcile the type
mismatch, which is invalid between tensors of different ranks. This
later trips up other passes.
---
 mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 36 ++++++++++-
 .../Dialect/Linalg/tile-and-fuse-tensors.mlir | 63 +++++++++++++++++++
 2 files changed, 97 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 223d728b0b27d..81b204df5a0aa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Dominance.h"
@@ -26,6 +27,7 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
@@ -235,6 +237,31 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
   return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
 }
 
+/// Create tensor.collapse_shape to drop dimensions in `dropDims` in tensor
+/// `from`.
+tensor::CollapseShapeOp collapseTo(OpBuilder &b, Location loc, Value from,
+                                   const llvm::SmallBitVector &dropDims) {
+  auto fromType = cast<ShapedType>(from.getType());
+  assert(fromType.getRank() == dropDims.size());
+  SmallVector<ReassociationIndices, 2> reassocIdxsVec;
+  ReassociationIndices reassocIdxs;
+
+  bool foundKeptDim = false;
+  for (int dim = 0; dim < fromType.getRank(); dim++) {
+    if (!dropDims.test(dim)) {
+      if (foundKeptDim) {
+        reassocIdxsVec.push_back(reassocIdxs);
+        reassocIdxs.clear();
+      }
+      foundKeptDim = true;
+    }
+    reassocIdxs.push_back(dim);
+  }
+  if (foundKeptDim) // Otherwise all dims are dropped: collapse to rank 0.
+    reassocIdxsVec.push_back(reassocIdxs);
+  return b.create<tensor::CollapseShapeOp>(loc, from, reassocIdxsVec);
+}
+
 FailureOr<FusionInfo>
 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                    OpOperand &consumerOpOperand) {
@@ -255,6 +282,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                << "\nNot fusable, not an extract_slice op: " << inputTensor);
     return failure();
   }
+  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
 
   // If producer is already in the same block as consumer, we are done.
   if (consumerOpOperand.get().getParentBlock() ==
@@ -272,12 +300,16 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
            consumerOpOperand);
 
   // Replace use.
+  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
+  Type consumerType = consumerOpOperand.get().getType();
+  // Rank-reduction occurred as part of the extract_slice.
+  if (cast<ShapedType>(consumerType).getRank() !=
+      cast<ShapedType>(def.getType()).getRank())
+    def = collapseTo(b, fusedProducer.getLoc(), def, droppedDims);
   // Canonicalizations are not guaranteed to have happened before constructing
   // `fusedProducer`. In the tensor case this can result in temporary type
   // mismatches. Insert a `tensor.cast` op to propagate the transformation
   // invariant that types are compatible.
-  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
-  Type consumerType = consumerOpOperand.get().getType();
   if (consumerType != def.getType())
     def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
   consumerOpOperand.set(def);
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 0f27a92c119cf..b4fbdfacde899 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -318,3 +318,66 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
   }
   return %for0 : tensor<64x128xf32>
 }
+
+// -----
+
+func.func @rank_reduced_extract_slice(%arg0: tensor<6x6x1x1x1x1xf32>, %arg1: tensor<6x6x1x1xf32>, %arg2: tensor<4x6xf32>) -> tensor<4x6xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c6 = arith.constant 6 : index
+  %cst = arith.constant 0.0 : f32
+  %init1 = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d6, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<6x6x1x1x1x1xf32>, tensor<6x6x1x1xf32>) outs(%fill1 : tensor<6x6x1x1x1x1xf32>) {
+  ^bb0(%in: f32, %in_1: f32, %out: f32):
+    %10 = arith.mulf %in, %in_1 : f32
+    %11 = arith.addf %out, %10 : f32
+    linalg.yield %11 : f32
+  } -> tensor<6x6x1x1x1x1xf32>
+  %init2 = tensor.empty() : tensor<4x6xf32>
+  %1 = scf.for %arg4 = %c0 to %c6 step %c2 iter_args(%arg3 = %init2) -> (tensor<4x6xf32>) {
+    %2 = tensor.extract_slice %0[0, %arg4, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
+    %init3 = tensor.empty() : tensor<4x2xf32>
+    %fill3 = linalg.fill ins(%cst : f32) outs(%init3 : tensor<4x2xf32>) -> tensor<4x2xf32>
+    %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%fill3 : tensor<4x2xf32>) {
+    ^bb0(%in: f32, %in_1: f32, %out: f32):
+      %20 = arith.mulf %in, %in_1 : f32
+      %21 = arith.addf %out, %20 : f32
+      linalg.yield %21 : f32
+    } -> tensor<4x2xf32>
+    %4 = tensor.insert_slice %3 into %arg3[0, %arg4] [4, 2] [1, 1]  : tensor<4x2xf32> into tensor<4x6xf32>
+    scf.yield %4 : tensor<4x6xf32>
+  }
+  return %1 : tensor<4x6xf32>
+}
+
+//       CHECK: func @rank_reduced_extract_slice(
+//  CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<6x6x1x1x1x1xf32>
+//  CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<6x6x1x1xf32>
+//  CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
+
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+//       CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+//       CHECK: %[[FILL_PROD:.*]] = linalg.fill ins({{%.*}} : f32)
+//  CHECK-SAME:     outs(%[[EMPTY_PROD]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+//       CHECK: %[[EMPTY_FOR:.*]] = tensor.empty() : tensor<4x6xf32>
+//       CHECK: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<4x2xf32>
+//       CHECK: %[[FILL_CONS:.*]] = linalg.fill ins({{%.*}} : f32)
+//  CHECK-SAME:     outs(%[[EMPTY_CONS]] : tensor<4x2xf32>) -> tensor<4x2xf32>
+//       CHECK: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[EMPTY_FOR]])
+//   CHECK-DAG:   %[[ARG0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[I]], 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]  : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
+//   CHECK-DAG:   %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[I]], 0, 0] [6, 2, 1, 1] [1, 1, 1, 1]  : tensor<6x6x1x1xf32> to tensor<6x2x1x1xf32>
+//   CHECK-DAG:   %[[FILL_PROD_SLICE:.*]] = tensor.extract_slice %[[FILL_PROD]][0, %[[I]], 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
+
+//       CHECK:    %[[MMUL_PROD:.*]] = linalg.generic
+//  CHECK-SAME:        ins(%[[ARG0_SLICE]], %[[ARG1_SLICE]] : tensor<6x2x1x1x1x1xf32>, tensor<6x2x1x1xf32>)
+//  CHECK-SAME:        outs(%[[FILL_PROD_SLICE]] : tensor<6x2x1x1x1x1xf32>)
+//       CHECK:    %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0\], \[1, 2, 3, 4, 5\]\]}} : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>
+//       CHECK:    %[[MMUL_CONS:.*]] = linalg.generic
+//  CHECK-SAME:        ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
+//  CHECK-SAME:        outs(%[[FILL_CONS]] : tensor<4x2xf32>)
+//       CHECK:   %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+//       CHECK:   scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
+//       CHECK: return %[[FOR]] : tensor<4x6xf32>

>From ce067327de33ee88397573ead73291764f15c627 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Tue, 25 Mar 2025 22:49:26 +0000
Subject: [PATCH 2/2] Add more comments and simplify test

---
 mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 12 ++-
 .../Dialect/Linalg/tile-and-fuse-tensors.mlir | 88 ++++++++-----------
 2 files changed, 45 insertions(+), 55 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 81b204df5a0aa..d18d6f7ff8dd8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -239,14 +239,20 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
 
 /// Create tensor.collapse_shape to drop dimensions in `dropDims` in tensor
 /// `from`.
-tensor::CollapseShapeOp collapseTo(OpBuilder &b, Location loc, Value from,
-                                   const llvm::SmallBitVector &dropDims) {
+static tensor::CollapseShapeOp collapseTo(OpBuilder &b, Location loc, Value from,
+                                          const llvm::SmallBitVector &dropDims) {
   auto fromType = cast<ShapedType>(from.getType());
-  assert(fromType.getRank() == dropDims.size());
+  assert(fromType.getRank() == dropDims.size() && "dropDims dimension does not match from tensor rank");
+  // Computed reassociation map for the corresponding tensor.collapse_shape.
   SmallVector<ReassociationIndices, 2> reassocIdxsVec;
+  // Current reassociation indices to add dropped dimension to.
   ReassociationIndices reassocIdxs;
 
   bool foundKeptDim = false;
+  // Dropped dimensions may appear before the first kept dimension or after
+  // the last one, so group each kept dimension with the contiguous dropped
+  // dimensions that follow it (and, for the first kept dimension, those that
+  // precede it). Emit a group each time the next kept dimension is found.
   for (int dim = 0; dim < fromType.getRank(); dim++) {
     if (!dropDims.test(dim)) {
       if (foundKeptDim) {
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index b4fbdfacde899..46b70a9c0edba 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -321,63 +321,47 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
 
 // -----
 
-func.func @rank_reduced_extract_slice(%arg0: tensor<6x6x1x1x1x1xf32>, %arg1: tensor<6x6x1x1xf32>, %arg2: tensor<4x6xf32>) -> tensor<4x6xf32> {
-  %c0 = arith.constant 0 : index
-  %c2 = arith.constant 2 : index
-  %c6 = arith.constant 6 : index
+func.func @rank_reduced_extract_slice(%cond : i1) -> tensor<6x2xf32> {
   %cst = arith.constant 0.0 : f32
-  %init1 = tensor.empty() : tensor<6x6x1x1x1x1xf32>
-  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
-  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d6, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<6x6x1x1x1x1xf32>, tensor<6x6x1x1xf32>) outs(%fill1 : tensor<6x6x1x1x1x1xf32>) {
-  ^bb0(%in: f32, %in_1: f32, %out: f32):
-    %10 = arith.mulf %in, %in_1 : f32
-    %11 = arith.addf %out, %10 : f32
-    linalg.yield %11 : f32
+  %cst1 = arith.constant 1.0 : f32
+
+  %empty1 = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+  %init1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} outs(%empty1 : tensor<6x6x1x1x1x1xf32>) {
+  ^bb0(%out: f32):
+    linalg.yield %cst : f32
   } -> tensor<6x6x1x1x1x1xf32>
-  %init2 = tensor.empty() : tensor<4x6xf32>
-  %1 = scf.for %arg4 = %c0 to %c6 step %c2 iter_args(%arg3 = %init2) -> (tensor<4x6xf32>) {
-    %2 = tensor.extract_slice %0[0, %arg4, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
-    %init3 = tensor.empty() : tensor<4x2xf32>
-    %fill3 = linalg.fill ins(%cst : f32) outs(%init3 : tensor<4x2xf32>) -> tensor<4x2xf32>
-    %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%fill3 : tensor<4x2xf32>) {
-    ^bb0(%in: f32, %in_1: f32, %out: f32):
-      %20 = arith.mulf %in, %in_1 : f32
-      %21 = arith.addf %out, %20 : f32
-      linalg.yield %21 : f32
-    } -> tensor<4x2xf32>
-    %4 = tensor.insert_slice %3 into %arg3[0, %arg4] [4, 2] [1, 1]  : tensor<4x2xf32> into tensor<4x6xf32>
-    scf.yield %4 : tensor<4x6xf32>
+
+  %if = scf.if %cond -> tensor<6x2xf32> {
+    %extract0 = tensor.extract_slice %init1[0, 0, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
+
+    %init2 = tensor.empty() : tensor<6x2xf32>
+    %add1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extract0 : tensor<6x2xf32>) outs(%init2 : tensor<6x2xf32>) {
+      ^bb0(%in: f32, %out: f32):
+        %add = arith.addf %in, %cst1 : f32
+        linalg.yield %add : f32
+    } -> tensor<6x2xf32>
+    scf.yield %add1 : tensor<6x2xf32>
+  } else {
+    %extract2 = tensor.extract_slice %init1[0, 2, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
+    scf.yield %extract2 : tensor<6x2xf32>
   }
-  return %1 : tensor<4x6xf32>
+
+  return %if : tensor<6x2xf32>
 }
 
 //       CHECK: func @rank_reduced_extract_slice(
-//  CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<6x6x1x1x1x1xf32>
-//  CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<6x6x1x1xf32>
-//  CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
+//  CHECK-SAME: %[[COND:[0-9a-z]*]]: i1
 
-//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-//   CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
 //       CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
-//       CHECK: %[[FILL_PROD:.*]] = linalg.fill ins({{%.*}} : f32)
-//  CHECK-SAME:     outs(%[[EMPTY_PROD]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
-//       CHECK: %[[EMPTY_FOR:.*]] = tensor.empty() : tensor<4x6xf32>
-//       CHECK: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<4x2xf32>
-//       CHECK: %[[FILL_CONS:.*]] = linalg.fill ins({{%.*}} : f32)
-//  CHECK-SAME:     outs(%[[EMPTY_CONS]] : tensor<4x2xf32>) -> tensor<4x2xf32>
-//       CHECK: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[EMPTY_FOR]])
-//   CHECK-DAG:   %[[ARG0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[I]], 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]  : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
-//   CHECK-DAG:   %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[I]], 0, 0] [6, 2, 1, 1] [1, 1, 1, 1]  : tensor<6x6x1x1xf32> to tensor<6x2x1x1xf32>
-//   CHECK-DAG:   %[[FILL_PROD_SLICE:.*]] = tensor.extract_slice %[[FILL_PROD]][0, %[[I]], 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
-
-//       CHECK:    %[[MMUL_PROD:.*]] = linalg.generic
-//  CHECK-SAME:        ins(%[[ARG0_SLICE]], %[[ARG1_SLICE]] : tensor<6x2x1x1x1x1xf32>, tensor<6x2x1x1xf32>)
-//  CHECK-SAME:        outs(%[[FILL_PROD_SLICE]] : tensor<6x2x1x1x1x1xf32>)
-//       CHECK:    %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0\], \[1, 2, 3, 4, 5\]\]}} : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>
-//       CHECK:    %[[MMUL_CONS:.*]] = linalg.generic
-//  CHECK-SAME:        ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
-//  CHECK-SAME:        outs(%[[FILL_CONS]] : tensor<4x2xf32>)
-//       CHECK:   %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
-//       CHECK:   scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
-//       CHECK: return %[[FOR]] : tensor<4x6xf32>
+//       CHECK: %[[FILL_PROD:.*]] = linalg.generic
+//  CHECK-SAME:     outs(%[[EMPTY_PROD]] : tensor<6x6x1x1x1x1xf32>)
+
+//       CHECK: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<6x2xf32>
+//       CHECK: %[[EXTRACT_SLICE_CONS:.*]] = tensor.extract_slice %[[EMPTY_PROD]][0, 0, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
+
+//       CHECK: %[[FILL_CONS:.*]] = linalg.generic
+//  CHECK-SAME:     outs(%[[EXTRACT_SLICE_CONS]] : tensor<6x2x1x1x1x1xf32>)
+//       CHECK: %[[CONS_COLLAPSE:.*]] = tensor.collapse_shape %[[FILL_CONS]] {{\[\[0\], \[1, 2, 3, 4, 5\]\]}} : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>
+//       CHECK: %[[ADD1_CONS:.*]] = linalg.generic
+//  CHECK-SAME:     ins(%[[CONS_COLLAPSE]] : tensor<6x2xf32>)
+//  CHECK-SAME:     outs(%[[EMPTY_CONS]] : tensor<6x2xf32>)



More information about the Mlir-commits mailing list