[Mlir-commits] [mlir] [mlir][Linalg] enable scalar lowering for linalg.pack (PR #178222)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 27 07:05:20 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Ryutaro Okada (sakupan102)

<details>
<summary>Changes</summary>

This change is part of an effort to upstream IREE’s code to MLIR. It implements generateScalarImplementation for linalg.pack so that pack can lower to scalar loop code.

---

Patch is 33.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/178222.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/Utils/Utils.h (+17) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+8-4) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+136-2) 
- (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+26) 
- (modified) mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir (+400) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 06e7a472a8182..c1f9db1e3530f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -44,6 +44,23 @@ SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp,
 SmallVector<int64_t> getUnPackInverseSrcPerm(linalg::UnPackOp,
                                              PackingMetadata &metadata);
 
+/// Returns a copy of `elements` whose entries at `offset + i` are replaced by
+/// `elements[offset + interchangeVector[i]]`; other entries are unchanged.
+template <typename T>
+SmallVector<T> interchange(ArrayRef<T> elements,
+                           ArrayRef<int64_t> interchangeVector,
+                           int64_t offset = 0) {
+  SmallVector<T> vec = llvm::to_vector(elements);
+  for (auto [idx, val] : llvm::enumerate(interchangeVector))
+    vec[idx + offset] = elements[val + offset];
+  return vec;
+}
+
+/// Returns the index in `dimsPos` of each dim it lists, taken in increasing
+/// dim order; dims outside [0, rank) are skipped. E.g. [2, 0] -> [1, 0].
+SmallVector<int64_t> computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos,
+                                                  int64_t rank);
+
 //===----------------------------------------------------------------------===//
 // General utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 40cabe20d1a4b..1162b96a67fc7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5035,8 +5035,14 @@ reifyResultShapesImpl(OpTy op, OpBuilder &builder,
                 "applies to only pack or unpack operations");
   int64_t destRank = op.getDestRank();
   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
-  reifiedReturnShapes[0] =
-      tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
+  if (op.hasPureTensorSemantics())
+    reifiedReturnShapes[0] =
+        tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
+  else if (op.hasPureBufferSemantics())
+    reifiedReturnShapes[0] =
+        memref::getMixedSizes(builder, op.getLoc(), op.getDest());
+  else
+    return failure(); // Neither pure tensor nor pure buffer semantics.
   return success();
 }
 
@@ -5434,8 +5440,6 @@ void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
 LogicalResult
 PackOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  if (!hasPureTensorSemantics())
-    return failure();
   return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 5d39c4731dd1b..3ab6e741c7366 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -80,6 +80,90 @@ static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
   return success();
 }
 
+/// Generates the innermost-loop body of the scalar lowering of `packOp`: loads
+/// the source element that maps to the destination indices `ivs` (or uses the
+/// padding value when that element is out of bounds) and stores it at `ivs`.
+static void generatePackOpScalarImplementationBody(PackOp packOp,
+                                                   OpBuilder &builder,
+                                                   Location loc,
+                                                   ValueRange ivs) {
+  // `ivs` follow the destination layout: tile loops are permuted by
+  // `outer_dims_perm` and point loops are ordered by `inner_dims_pos`. Undo
+  // both permutations to get the canonical blocked order ABCabc, where the
+  // tile loop of source dim `d` is at position `d` and the point loops of the
+  // tiled dims follow at `inputRank + k`, sorted by source dim.
+  int64_t inputRank = packOp.getSourceRank();
+  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+
+  SmallVector<Value> interchangedIvs = llvm::to_vector(ivs);
+  interchangedIvs = interchange<Value>(
+      interchangedIvs, computeInterchangeFromDimPos(innerDimsPos, inputRank),
+      /*offset=*/inputRank);
+  if (!outerDimsPerm.empty()) {
+    interchangedIvs = interchange<Value>(
+        interchangedIvs, computeInterchangeFromDimPos(outerDimsPerm, inputRank),
+        /*offset=*/0);
+  }
+
+  // A tiled dim `d` is read at `tileIv * tileSize + pointIv`; others at `iv`.
+  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+      packOp.getDimAndTileMapping();
+  MLIRContext *ctx = builder.getContext();
+  AffineExpr i, j, tile;
+  bindDims(ctx, i, j);
+  bindSymbols(ctx, tile);
+  SmallVector<OpFoldResult> sourceIndices;
+  sourceIndices.reserve(inputRank);
+  int64_t pointLoopsOffset = 0;
+  for (int64_t dim : llvm::seq<int64_t>(0, inputRank)) {
+    auto it = dimAndTileMapping.find(dim);
+    if (it == dimAndTileMapping.end()) {
+      sourceIndices.push_back(interchangedIvs[dim]);
+      continue;
+    }
+    sourceIndices.push_back(affine::makeComposedFoldedAffineApply(
+        builder, loc, i * tile + j,
+        ArrayRef<OpFoldResult>{interchangedIvs[dim],
+                               interchangedIvs[inputRank + pointLoopsOffset],
+                               it->second}));
+    ++pointLoopsOffset;
+  }
+  // Materialize the indices once; the bounds check and the load share them.
+  SmallVector<Value> sourceIndexValues =
+      getValueOrCreateConstantIndexOp(builder, loc, sourceIndices);
+
+  Value scalar;
+  if (Value paddingValue = packOp.getPaddingValue()) {
+    // Only load in-bounds elements; out-of-bounds ones take the padding value.
+    ArithBuilder arithBuilder(builder, loc);
+    Value isInBounds;
+    for (int64_t dim : llvm::seq<int64_t>(0, inputRank)) {
+      Value cond = arithBuilder.slt(
+          sourceIndexValues[dim],
+          createOrFoldDimOp(builder, loc, packOp.getSource(), dim));
+      isInBounds = isInBounds ? arithBuilder._and(isInBounds, cond) : cond;
+    }
+    scalar = scf::IfOp::create(
+                 builder, loc, isInBounds, /*thenBuilder=*/
+                 [&](OpBuilder &b, Location l) {
+                   Value loaded = memref::LoadOp::create(
+                       b, l, packOp.getSource(), sourceIndexValues);
+                   scf::YieldOp::create(b, l, loaded);
+                 },
+                 /*elseBuilder=*/
+                 [&](OpBuilder &b, Location l) {
+                   scf::YieldOp::create(b, l, paddingValue);
+                 })
+                 .getResult(0);
+  } else {
+    scalar = memref::LoadOp::create(builder, loc, packOp.getSource(),
+                                    sourceIndexValues);
+  }
+
+  memref::StoreOp::create(builder, loc, scalar, packOp.getDest(), ivs);
+}
+
 //===----------------------------------------------------------------------===//
 // External Model for implementing `TilingInterface` for `LinalgOp`s.
 //===----------------------------------------------------------------------===//
@@ -725,7 +809,7 @@ static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
   OpFoldResult zero = builder.getIndexAttr(0);
   OpFoldResult one = builder.getIndexAttr(1);
   ReifiedRankedShapedTypeDims resultShape;
-  (void)reifyResultShapes(builder, op, resultShape);
+  (void)op.reifyResultShapes(builder, resultShape);
   SmallVector<Range> loopBounds(rank);
   for (auto dim : llvm::seq<int64_t>(0, rank)) {
     loopBounds[dim].offset = zero;
@@ -865,7 +949,9 @@ struct PackOpTiling
     resultOffsets.append(outputRank - inputRank, zeroAttr);
 
     ReifiedRankedShapedTypeDims outputShape;
-    (void)reifyResultShapes(b, packOp, outputShape);
+    if (failed(packOp.reifyResultShapes(b, outputShape))) {
+      return packOp.getOperation()->emitOpError("failed to reify result shape");
+    }
     resultSizes.assign(sizes.begin(), sizes.end());
     for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
       resultSizes.push_back(outputShape[0][dataTileDim]);
@@ -1058,6 +1144,54 @@ struct PackOpTiling
         SmallVector<Value>(tiledPackOp->getResults()),
         llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
   }
+
+  LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
+                                             Location loc,
+                                             ValueRange ivs) const {
+    OpBuilder::InsertionGuard g(builder);
+    auto packOp = cast<PackOp>(op);
+    int64_t srcRank = packOp.getSourceRank(), destRank = packOp.getDestRank();
+    // `ivs` give the destination position in the non data-tile dimensions.
+    SmallVector<Value> ivVec = llvm::to_vector(ivs);
+    // Without inner tiles (destRank == srcRank) there is no data tile to loop
+    // over: `ivs` address a single element, so emit the body directly.
+    if (destRank == srcRank) {
+      generatePackOpScalarImplementationBody(packOp, builder, loc, ivVec);
+      return success();
+    }
+    ReifiedRankedShapedTypeDims outputShape;
+    if (failed(packOp.reifyResultShapes(builder, outputShape)))
+      return packOp->emitOpError("failed to reify result shape");
+    if (outputShape.size() != 1 ||
+        static_cast<int64_t>(outputShape[0].size()) != destRank) {
+      return packOp->emitOpError("expected shape of one result value of rank ")
+             << destRank;
+    }
+    // Generate the loops that iterate over the data tile; the outer ones
+    // simply iterate over the tile dimensions.
+    Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
+    Value one = arith::ConstantIndexOp::create(builder, loc, 1);
+    for (int64_t dataTileDim : llvm::seq<int64_t>(srcRank, destRank - 1)) {
+      Value ub = getValueOrCreateConstantIndexOp(builder, loc,
+                                                 outputShape[0][dataTileDim]);
+      scf::ForOp loop = scf::ForOp::create(builder, loc, zero, ub, one);
+      builder.setInsertionPointToStart(loop.getBody());
+      ivVec.push_back(loop.getInductionVar());
+    }
+    // The body of the innermost loop does the actual data movement.
+    scf::ForOp::create(
+        builder, loc, zero,
+        getValueOrCreateConstantIndexOp(builder, loc, outputShape[0].back()),
+        one, ValueRange{},
+        [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
+            ValueRange regionIterArgs) {
+          ivVec.push_back(iv);
+          generatePackOpScalarImplementationBody(packOp, bodyBuilder, bodyLoc,
+                                                 ivVec);
+          scf::YieldOp::create(bodyBuilder, bodyLoc);
+        });
+    return success();
+  }
 };
 
 struct UnpackTileDimInfo {
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index a1ee6b307caf5..c922ee4b7b6f0 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -192,6 +192,32 @@ SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp,
   return unpackInvSrcPerm;
 }
 
+SmallVector<int64_t> computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos,
+                                                  int64_t rank) {
+  // First map each dim listed in `dimsPos` to its index there. For example,
+  // dims_pos = [2, 0] maps to:
+  // [
+  //  [ key: 2, value: 0]
+  //  [ key: 0, value: 1]
+  // ]
+  // where each key is a dim from dims_pos and each value is its index.
+  DenseMap<int64_t, int64_t> dimToIndex;
+  for (auto [index, dim] : llvm::enumerate(dimsPos))
+    dimToIndex[dim] = index;
+
+  // Visiting the dims in increasing order and emitting their indices yields
+  // the interchange vector ([1, 0] for the example above). Dims outside
+  // [0, rank) are never visited and therefore ignored.
+  SmallVector<int64_t> interchangeVector;
+  interchangeVector.reserve(dimsPos.size());
+  for (int64_t dim : llvm::seq<int64_t>(0, rank)) {
+    auto it = dimToIndex.find(dim);
+    if (it != dimToIndex.end())
+      interchangeVector.push_back(it->second);
+  }
+  return interchangeVector;
+}
+
 bool allIndexingsAreProjectedPermutation(LinalgOp op) {
   return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
     return m.isProjectedPermutation(/*allowZeroInResults=*/true);
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index aa8882d21698c..50c21ae721ab3 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -400,3 +400,403 @@ module attributes {transform.with_named_sequence} {
 // CHECK:         scf.for %[[K:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
 // CHECK:           %[[ELEM:.*]] = memref.load %[[IN]][%[[I]], %[[K]]]
 // CHECK:           memref.store %[[ELEM]], %[[OUT]][%[[I]], %[[J]], %[[K]]]
+
+// -----
+
+func.func @NC_to_NCnc(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) {
+  linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg1 : memref<128x256xf32> -> memref<4x8x32x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK:       #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK:       func.func @NC_to_NCnc(
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C32:.*]] = arith.constant 32 : index
+// CHECK:         scf.for %[[N:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK:           scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:             scf.for %[[n:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK:               scf.for %[[c:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK-DAG:             %[[applyMapI:.*]] = affine.apply #[[MAP]](%[[N]], %[[n]])
+// CHECK-DAG:             %[[applyMapJ:.*]] = affine.apply #[[MAP]](%[[C]], %[[c]])
+// CHECK:                 %[[scalar:.*]] = memref.load %arg0[%[[applyMapI]], %[[applyMapJ]]] : memref<128x256xf32>
+// CHECK:                 memref.store %[[scalar]], %arg1[%[[N]], %[[C]], %[[n]], %[[c]]] : memref<4x8x32x32xf32>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+
+// -----
+
+func.func @NC_to_NCnc_pad_static(%arg0: memref<13x15xf32>, %arg1: memref<2x8x8x2xf32>, %arg2: f32) {
+  linalg.pack %arg0 padding_value(%arg2 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : memref<13x15xf32> -> memref<2x8x8x2xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-DAG:   #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK-DAG:   #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
+// CHECK:       func.func @NC_to_NCnc_pad_static(
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C13:.*]] = arith.constant 13 : index
+// CHECK-DAG:     %[[C15:.*]] = arith.constant 15 : index
+// CHECK:           scf.for %[[N:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+// CHECK:             scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:               scf.for %[[n:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:                 scf.for %[[c:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+// CHECK-DAG:               %[[applyMapI:.*]] = affine.apply #[[MAP0]](%[[N]], %[[n]])
+// CHECK-DAG:               %[[applyMapJ:.*]] = affine.apply #[[MAP1]](%[[C]], %[[c]])
+// CHECK:                   %[[isIInBound:.*]] = arith.cmpi slt, %[[applyMapI]], %[[C13]] : index
+// CHECK:                   %[[isJInBound:.*]] = arith.cmpi slt, %[[applyMapJ]], %[[C15]] : index
+// CHECK:                   %[[isAllInBounds:.*]] = arith.andi %[[isIInBound]], %[[isJInBound]] : i1
+// CHECK:                   %[[scalar:.*]] = scf.if %[[isAllInBounds]] -> (f32) {
+// CHECK:                     %[[load:.*]] = memref.load %arg0[%[[applyMapI]], %[[applyMapJ]]] : memref<13x15xf32>
+// CHECK:                     scf.yield %[[load]]
+// CHECK:                   } else {
+// CHECK:                     scf.yield %arg2
+// CHECK:                   }
+// CHECK:                   memref.store %[[scalar]], %arg1[%[[N]], %[[C]], %[[n]], %[[c]]] : memref<2x8x8x2xf32>
+
+// -----
+
+func.func @KC_to_KCck(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) {
+  linalg.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 32] into %arg1 : memref<128x256xf32> -> memref<4x8x32x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK:       #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK:       func.func @KC_to_KCck(
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C32:.*]] = arith.constant 32 : index
+// CHECK:         scf.for %[[K:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK:           scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:             scf.for %[[c:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK:               scf.for %[[k:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK-DAG:             %[[applyMapC:.*]] = affine.apply #[[MAP]](%[[C]], %[[c]])
+// CHECK-DAG:             %[[applyMapK:.*]] = affine.apply #[[MAP]](%[[K]], %[[k]])
+// CHECK:                 %[[scalar:.*]] = memref.load %arg0[%[[applyMapK]], %[[applyMapC]]] : memref<128x256xf32>
+// CHECK:                 memref.store %[[scalar]], %arg1[%[[K]], %[[C]], %[[c]], %[[k]]] : memref<4x8x32x32xf32>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+
+// -----
+
+// This should be a simple expand shape.
+func.func @KC_to_KCc(%arg0: memref<128x256xf32>, %arg1: memref<128x8x32xf32>) {
+  linalg.pack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %arg1 : memref<128x256xf32> -> memref<128x8x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK:       #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK:       func.func @KC_to_KCc(
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C128:.*]] = arith.constant 128 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C32:.*]] = arith.constant 32 : index
+// CHECK:         scf.for %[[K:.*]] = %[[C0]] to %[[C128]] step %[[C1]] {
+// CHECK:           scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:             scf.for %[[c:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK:               %[[applyMapC:.*]] = affine.apply #[[MAP]](%[[C]], %[[c]])
+// CHECK:               %[[scalar:.*]] = memref.load %arg0[%[[K]], %[[applyMapC]]] : memref<128x256xf32>
+// CHECK:               memref.store %[[scalar]], %arg1[%[[K]],...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/178222


More information about the Mlir-commits mailing list