[Mlir-commits] [mlir] [MLIR][Linalg] pack, unpack to take memref inputs (PR #129036)

Hyunsung Lee llvmlistbot at llvm.org
Tue Apr 15 13:48:24 PDT 2025


https://github.com/ita9naiwa updated https://github.com/llvm/llvm-project/pull/129036

>From 4d523adc3cf5eb581c43395e66aaa0012dbc179b Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Thu, 27 Feb 2025 19:54:30 +0900
Subject: [PATCH 01/33] draft: let linalg.pack/unpack accept AnyShaped operands

---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    | 72 +++++++++++++++++--
 .../Dialect/Linalg/IR/RelayoutOpInterface.td  |  1 +
 .../mlir/Dialect/Utils/ReshapeOpsUtils.h      |  4 +-
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp    |  2 +-
 4 files changed, 69 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 1e48a5e3a20ee..f8a4657c564ce 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -77,7 +77,20 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
     /// with `inner_dims_pos` rather than the packed tensor.
     SmallVector<int64_t> getTiledOuterDims();
   }];
-
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    ShapedType getInputType() {
+        return cast<ShapedType>(getInput().getType());
+      }
+      ShapedType getOutputType() {
+        return cast<ShapedType>(getOutput().getType());
+      }
+      int64_t getInputRank() {
+        return getInputType().getRank();
+      }
+      int64_t getOutputRank() {
+        return getOutputType().getRank();
+      }
+    }];
   let hasVerifier = 1;
 }
 
@@ -152,14 +165,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     // Note: Only tiled dimensions can be padded.
     ```
   }];
-  let arguments = (ins AnyRankedTensor:$source,
-                       AnyRankedTensor:$dest,
+  let arguments = (ins AnyShaped:$source,
+                       AnyShaped:$dest,
                        Optional<AnyType>:$padding_value,
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
                        DenseI64ArrayAttr:$inner_dims_pos,
                        Variadic<Index>:$inner_tiles,
                        DenseI64ArrayAttr:$static_inner_tiles);
-  let results = (outs AnyRankedTensor:$result);
+  let results = (outs AnyShaped:$result);
   let assemblyFormat = [{
     $source
     (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
@@ -179,6 +192,28 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
+    Value getOutput() {
+      return getDpsInitOperand(0)->get();
+    }
+
+    // Return the input operand.
+    Value getInput() {
+      return getDpsInputOperand(0)->get();
+    }
+    ShapedType getInputType() {
+      return cast<ShapedType>(getInput().getType());
+    }
+    ShapedType getOutputType() {
+      return cast<ShapedType>(getDest().getType());
+    }
+    int64_t getInputRank() {
+      return getInputType().getRank();
+    }
+    int64_t getOutputRank() {
+      return getOutputType().getRank();
+    }
+
+    LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
     // Method to get the shape of the result as `SmallVector<OpFoldResult>`.
     // This is a static method to allow getting the shape of the destination
     // expected while creating a `pack` op.
@@ -229,6 +264,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     /// 2. pads the other ones, and
     /// 3. doesn't shuffle the dimensions
     bool isLikePad();
+
   }];
 
   let hasCanonicalizeMethod = 1;
@@ -279,13 +315,13 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
         : tensor<8x16x8x32xf32> -> tensor<128x256xf32>
     ```
   }];
-  let arguments = (ins AnyRankedTensor:$source,
-                       AnyRankedTensor:$dest,
+  let arguments = (ins AnyShaped:$source,
+                       AnyShaped:$dest,
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
                        DenseI64ArrayAttr:$inner_dims_pos,
                        Variadic<Index>:$inner_tiles,
                        DenseI64ArrayAttr:$static_inner_tiles);
-  let results = (outs AnyRankedTensor:$result);
+  let results = (outs AnyShaped:$result);
   let assemblyFormat = [{
     $source
     (`outer_dims_perm` `=` $outer_dims_perm^)?
@@ -303,6 +339,28 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
+    Value getOutput() {
+      return getDpsInitOperand(0)->get();
+    }
+
+    // Return the input operand.
+    Value getInput() {
+      return getDpsInputOperand(0)->get();
+    }
+    ShapedType getInputType() {
+      return cast<ShapedType>(getInput().getType());
+    }
+    ShapedType getOutputType() {
+      return cast<ShapedType>(getDest().getType()); // getDest() 사용
+    }
+    int64_t getInputRank() {
+      return getInputType().getRank();
+    }
+    int64_t getOutputRank() {
+      return getOutputType().getRank();
+    }
+    LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
+
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
         ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
index 2dec2fc4396f4..467d862d277eb 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
@@ -10,6 +10,7 @@
 #define LINALG_IR_RELAYOUTOPINTERFACE
 
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/IR/OpBase.td"
 
 def LinalgRelayoutOpInterface : OpInterface<"RelayoutOpInterface"> {
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 3af89a6ab3799..a86bf74a7b6a1 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -451,7 +451,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
 ///    %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
 ///          tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
 ///
-///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] : 
+///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
 ///          tensor<1x1x1x10xf32> into tensor<1x10xf32>
 ///    %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
 ///          tensor<1x10xf32> into tensor<10x10xf32>
@@ -573,7 +573,7 @@ PackingMetadata computePackingMetadata(int64_t packedRank,
 /// Removes the op and replaces the constant with a new constant of the result
 /// shape. When an optional cst attribute is passed, it is reshaped only if the
 /// splat value matches the value in the attribute.
-OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result,
+OpFoldResult reshapeConstantSource(DenseElementsAttr source, ShapedType result,
                                    std::optional<Attribute> cst = std::nullopt);
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 0336423c57b1d..4267732571801 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -485,7 +485,7 @@ PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
 }
 
 OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
-                                         TensorType result,
+  ShapedType result,
                                          std::optional<Attribute> cst) {
   if (source && source.isSplat() && result.hasStaticShape() &&
       (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))

>From 4f2dbf4848092942a7932387e39d3c1220d78923 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 08:00:32 +0900
Subject: [PATCH 02/33] draft: drop per-op input/output accessors from pack/unpack

---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    | 44 -------------------
 1 file changed, 44 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index f8a4657c564ce..6e2c6171132f5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -192,28 +192,6 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
-    Value getOutput() {
-      return getDpsInitOperand(0)->get();
-    }
-
-    // Return the input operand.
-    Value getInput() {
-      return getDpsInputOperand(0)->get();
-    }
-    ShapedType getInputType() {
-      return cast<ShapedType>(getInput().getType());
-    }
-    ShapedType getOutputType() {
-      return cast<ShapedType>(getDest().getType());
-    }
-    int64_t getInputRank() {
-      return getInputType().getRank();
-    }
-    int64_t getOutputRank() {
-      return getOutputType().getRank();
-    }
-
-    LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
     // Method to get the shape of the result as `SmallVector<OpFoldResult>`.
     // This is a static method to allow getting the shape of the destination
     // expected while creating a `pack` op.
@@ -339,28 +317,6 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
-    Value getOutput() {
-      return getDpsInitOperand(0)->get();
-    }
-
-    // Return the input operand.
-    Value getInput() {
-      return getDpsInputOperand(0)->get();
-    }
-    ShapedType getInputType() {
-      return cast<ShapedType>(getInput().getType());
-    }
-    ShapedType getOutputType() {
-      return cast<ShapedType>(getDest().getType()); // getDest() 사용
-    }
-    int64_t getInputRank() {
-      return getInputType().getRank();
-    }
-    int64_t getOutputRank() {
-      return getOutputType().getRank();
-    }
-    LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
-
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
         ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);

>From 226230c9445084671531d755d5c3f5612bed7d67 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 08:01:05 +0900
Subject: [PATCH 03/33] draft: drop extraClassDeclaration override from Linalg_RelayoutOp

---
 .../mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td   | 15 +--------------
 1 file changed, 1 insertion(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 6e2c6171132f5..c68c395fc6337 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -77,20 +77,7 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
     /// with `inner_dims_pos` rather than the packed tensor.
     SmallVector<int64_t> getTiledOuterDims();
   }];
-  let extraClassDeclaration = commonExtraClassDeclaration # [{
-    ShapedType getInputType() {
-        return cast<ShapedType>(getInput().getType());
-      }
-      ShapedType getOutputType() {
-        return cast<ShapedType>(getOutput().getType());
-      }
-      int64_t getInputRank() {
-        return getInputType().getRank();
-      }
-      int64_t getOutputRank() {
-        return getOutputType().getRank();
-      }
-    }];
+
   let hasVerifier = 1;
 }
 

>From 0c184dfc85cdb0d89d62aa8cafc4f752e1acc654 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 09:44:08 +0900
Subject: [PATCH 04/33] init: switch pack/unpack source/dest types to ShapedType

---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    | 10 +++---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 14 ++++-----
 .../Transforms/PackAndUnpackPatterns.cpp      | 12 +++----
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 31 +++++++++++++++----
 .../Linalg/Transforms/Vectorization.cpp       |  2 +-
 mlir/lib/Tools/mlir-opt/launch.json           | 13 ++++++++
 6 files changed, 57 insertions(+), 25 deletions(-)
 create mode 100644 mlir/lib/Tools/mlir-opt/launch.json

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index c68c395fc6337..785c7cc924159 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -43,10 +43,10 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
   code commonExtraClassDeclaration = [{
     size_t getSourceRank() { return getSourceType().getRank(); };
     size_t getDestRank() { return getDestType().getRank(); };
-    RankedTensorType getSourceType() {
-      return ::llvm::cast<RankedTensorType>(getSource().getType()); };
-    RankedTensorType getDestType() {
-      return ::llvm::cast<RankedTensorType>(getDest().getType()); };
+    ShapedType getSourceType() {
+      return ::llvm::cast<ShapedType>(getSource().getType()); };
+    ShapedType getDestType() {
+      return ::llvm::cast<ShapedType>(getDest().getType()); };
 
     MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
 
@@ -190,7 +190,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     // Method to get the `RankedTensorType` of the result based on the inner
     // tiles, position of the inner tiles (innerDimsPos)  and interchange vector
     // of outer loops (outerDimsPerm).
-    static RankedTensorType inferPackedType(RankedTensorType sourceType,
+    static RankedTensorType inferPackedType(ShapedType sourceType,
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
         ArrayRef<int64_t> outerDimsPerm = {});
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..f4f08d9d4acf7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -803,7 +803,7 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
           rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
     }
 
-    RankedTensorType srcPadType = srcPadOp.getSourceType();
+    ShapedType srcPadType = srcPadOp.getSourceType();
     SmallVector<OpFoldResult, 4> newSizes;
     for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
       if (srcPadType.isDynamicDim(i)) {
@@ -4433,7 +4433,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
     return op->emitError("invalid zero tile factor");
 
   // Verify inner_dims_pos and outer_dims_perm.
-  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
+  ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                       ? packOrUnPack.getSourceType()
                                       : packOrUnPack.getDestType();
   size_t unpackedRank = unpackedType.getRank();
@@ -4747,7 +4747,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
 
 /// Get the expected packed type based on source type, tile factors, position of
 /// the inner tiles and permutation of the outer tiled loop.
-RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
+RankedTensorType PackOp::inferPackedType(ShapedType sourceType,
                                          ArrayRef<int64_t> innerTileSizes,
                                          ArrayRef<int64_t> innerDimsPos,
                                          ArrayRef<int64_t> outerDimsPerm) {
@@ -4943,7 +4943,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
           rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
     }
     Value dest = packOp.getDest();
-    RankedTensorType originalResultType = packOp.getDestType();
+    ShapedType originalResultType = packOp.getDestType();
     bool needUpdateDestType = (destShape != originalResultType.getShape());
     if (needUpdateDestType) {
       auto newDestType = packOp.getDestType().clone(destShape);
@@ -4953,7 +4953,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     rewriter.modifyOpInPlace(packOp, [&] {
       packOp.getSourceMutable().assign(source);
       packOp.getDestMutable().assign(dest);
-      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+      packOp.getResult().setType(cast<ShapedType>(dest.getType()));
     });
     // Insert a cast if needed
     if (needUpdateDestType) {
@@ -4970,7 +4970,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
 
 template <typename PackOrUnpackOp>
 static bool isLikePadUnPad(PackOrUnpackOp packOp,
-                           RankedTensorType packedTensorType) {
+                           ShapedType packedTensorType) {
   static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
                     std::is_same<PackOrUnpackOp, UnPackOp>::value,
                 "Function meant for pack/unpack");
@@ -5274,7 +5274,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
 }
 
 bool UnPackOp::isLikeUnPad() {
-  RankedTensorType packedTensorType = getSourceType();
+  ShapedType packedTensorType = getSourceType();
   return isLikePadUnPad(*this, packedTensorType);
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 0984b6988b93b..599aa3b6668df 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -111,7 +111,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
     if (packOp.getPaddingValue())
       return rewriter.notifyMatchFailure(packOp, "expects no padding value");
 
-    RankedTensorType sourceType = packOp.getSourceType();
+    ShapedType sourceType = packOp.getSourceType();
     if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
         failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
                           packOp.getStaticTiles())) &&
@@ -119,7 +119,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
       return failure();
     }
 
-    RankedTensorType destType = packOp.getDestType();
+    ShapedType destType = packOp.getDestType();
     auto reassociation =
         getReassociationIndicesForReshape(sourceType, destType);
     if (!reassociation)
@@ -157,8 +157,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
           "expects outer_dims_perm is empty or an identity permutation");
     }
 
-    RankedTensorType sourceType = unpackOp.getSourceType();
-    RankedTensorType destType = unpackOp.getDestType();
+    ShapedType sourceType = unpackOp.getSourceType();
+    ShapedType destType = unpackOp.getDestType();
     if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
       return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
 
@@ -173,7 +173,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
 
   LogicalResult matchAndRewrite(UnPackOp unpackOp,
                                 PatternRewriter &rewriter) const override {
-    RankedTensorType destType = unpackOp.getDestType();
+    ShapedType destType = unpackOp.getDestType();
     if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
         failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
                           unpackOp.getStaticTiles())) &&
@@ -181,7 +181,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
       return failure();
     }
 
-    RankedTensorType sourceType = unpackOp.getSourceType();
+    ShapedType sourceType = unpackOp.getSourceType();
     auto reassociation =
         getReassociationIndicesForReshape(sourceType, destType);
     if (!reassociation)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index dcd50cc44f81b..7ed211841c53f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -359,7 +359,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
 
-  RankedTensorType packedTensorType = unPackOp.getSourceType();
+  ShapedType packedTensorType = unPackOp.getSourceType();
   int64_t packedRank = packedTensorType.getRank();
 
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -396,10 +396,29 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
   applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
 
   // 3. Transpose packedShape to stripMinedShape.
-  RankedTensorType stripMinedTensorType =
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
-  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
-      stripMinedTensorType, packingMetadata.reassociations);
+  ShapedType stripMinedType;
+  if (auto tensorType = packedTensorType.dyn_cast<TensorType>()) {
+    stripMinedType =
+        RankedTensorType::get(stripMinedShape, tensorType.getElementType());
+  } else if (auto memrefType = packedTensorType.dyn_cast<MemRefType>()) {
+    stripMinedType =
+        MemRefType::get(stripMinedShape, memrefType.getElementType());
+  }
+  ShapedType collapsedType;
+  if (stripMinedType.isa<TensorType>()) {
+    collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+        stripMinedType.cast<RankedTensorType>(),
+        packingMetadata.reassociations);
+  } else if (stripMinedType.isa<MemRefType>()) {
+    auto memrefTy = stripMinedType.cast<MemRefType>();
+    auto tensorTy =
+        RankedTensorType::get(memrefTy.getShape(), memrefTy.getElementType());
+    auto collapsedTensorType = tensor::CollapseShapeOp::inferCollapsedType(
+        tensorTy, packingMetadata.reassociations);
+    // tensor collapsed type을 memref로 재구성 (같은 메모리 공간 유지)
+    collapsedType = MemRefType::get(collapsedTensorType.getShape(),
+                                    collapsedTensorType.getElementType());
+  }
 
   // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
   // permutation.
@@ -407,7 +426,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
       tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
   applyPermutationToVector(dims, packedToStripMinedShapePerm);
   auto emptyOp = rewriter.create<tensor::EmptyOp>(
-      loc, dims, stripMinedTensorType.getElementType());
+      loc, dims, stripMinedType.getElementType());
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
       loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ae04c2b6b2a5b..25ad5e38addbe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1669,7 +1669,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);
 
-  RankedTensorType unpackTensorType = unpackOp.getSourceType();
+  ShapedType unpackTensorType = unpackOp.getSourceType();
 
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
diff --git a/mlir/lib/Tools/mlir-opt/launch.json b/mlir/lib/Tools/mlir-opt/launch.json
new file mode 100644
index 0000000000000..5a686d02e2dfb
--- /dev/null
+++ b/mlir/lib/Tools/mlir-opt/launch.json
@@ -0,0 +1,13 @@
+{
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "ma",
+            "type": "lldb",
+            "request": "launch",
+            "program": "/Users/ita/src/iree-build/tools/iree-opt --show-dialects",
+            "args": [],
+            "cwd": "${workspaceFolder}"
+        }
+    ]
+}

>From 19201c69e23578a69583bb98415f9c9583cb5c41 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 14:50:46 +0900
Subject: [PATCH 05/33] lint; remove stray launch.json

---
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp |  1 -
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp        | 12 ++++++------
 mlir/lib/Tools/mlir-opt/launch.json               | 13 -------------
 3 files changed, 6 insertions(+), 20 deletions(-)
 delete mode 100644 mlir/lib/Tools/mlir-opt/launch.json

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 7ed211841c53f..36e01ef46b30b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -415,7 +415,6 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
         RankedTensorType::get(memrefTy.getShape(), memrefTy.getElementType());
     auto collapsedTensorType = tensor::CollapseShapeOp::inferCollapsedType(
         tensorTy, packingMetadata.reassociations);
-    // tensor collapsed type을 memref로 재구성 (같은 메모리 공간 유지)
     collapsedType = MemRefType::get(collapsedTensorType.getShape(),
                                     collapsedTensorType.getElementType());
   }
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 4267732571801..9a2bd3493f6af 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -315,11 +315,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
     // have proven that these are not sliced. In this case we just take
     // the full extent of each dimension in the reassociation list.
     if (linearizedDimensions[it.index()]) {
-      llvm::append_range(
-          offsetsSizesAndStrides,
-          llvm::map_range(it.value(), [&](int64_t idx) -> Range {
-            return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
-          }));
+      llvm::append_range(offsetsSizesAndStrides,
+                         llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+                           return {zeroAttr, collapseShapeInputShape[idx],
+                                   oneAttr};
+                         }));
       continue;
     }
 
@@ -485,7 +485,7 @@ PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
 }
 
 OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
-  ShapedType result,
+                                         ShapedType result,
                                          std::optional<Attribute> cst) {
   if (source && source.isSplat() && result.hasStaticShape() &&
       (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
diff --git a/mlir/lib/Tools/mlir-opt/launch.json b/mlir/lib/Tools/mlir-opt/launch.json
deleted file mode 100644
index 5a686d02e2dfb..0000000000000
--- a/mlir/lib/Tools/mlir-opt/launch.json
+++ /dev/null
@@ -1,13 +0,0 @@
-{
-    "version": "0.2.0",
-    "configurations": [
-        {
-            "name": "ma",
-            "type": "lldb",
-            "request": "launch",
-            "program": "/Users/ita/src/iree-build/tools/iree-opt --show-dialects",
-            "args": [],
-            "cwd": "${workspaceFolder}"
-        }
-    ]
-}

>From b99b92030f2f664607f43554d2b7bc722c98c2c1 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 15:26:13 +0900
Subject: [PATCH 06/33] lint

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f4f08d9d4acf7..eca8cea3e6323 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4434,8 +4434,8 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
 
   // Verify inner_dims_pos and outer_dims_perm.
   ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
-                                      ? packOrUnPack.getSourceType()
-                                      : packOrUnPack.getDestType();
+                                ? packOrUnPack.getSourceType()
+                                : packOrUnPack.getDestType();
   size_t unpackedRank = unpackedType.getRank();
   ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
   ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
@@ -4969,8 +4969,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
 }
 
 template <typename PackOrUnpackOp>
-static bool isLikePadUnPad(PackOrUnpackOp packOp,
-                           ShapedType packedTensorType) {
+static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
   static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
                     std::is_same<PackOrUnpackOp, UnPackOp>::value,
                 "Function meant for pack/unpack");

>From be6a1193579633d7b678a30a9a80e5dee89a51e1 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 16:19:20 +0900
Subject: [PATCH 07/33] add: handle memref results in PackOp::isLikePad

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index eca8cea3e6323..a19039fbca67d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5001,9 +5001,12 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
 }
 
 bool PackOp::isLikePad() {
-  auto packedTensorType =
-      llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
-  return isLikePadUnPad(*this, packedTensorType);
+  if (auto packedTensorType =
+          llvm::dyn_cast<RankedTensorType>((*this)->getResultTypes().front()))
+    return isLikePadUnPad(*this, packedTensorType);
+  if (auto packedTensorType =
+          llvm::dyn_cast<MemRefType>((*this)->getResultTypes().front()))
+    return isLikePadUnPad(*this, packedTensorType);
 }
 
 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {

>From eee8805c351e7b8100d3e73d1e67c1c06e065962 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 1 Mar 2025 09:04:26 +0900
Subject: [PATCH 08/33] remove tensor casting

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  5 +++
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 10 ++----
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 32 ++++++++++++++++++-
 3 files changed, 39 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 4c8a214049ea9..8bcc1882b454d 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1799,6 +1799,11 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
 
     static MemRefType computeCollapsedType(
         MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
+    static MemRefType
+        inferCollapsedType(MemRefType type, ArrayRef<AffineMap> reassociation);
+    static MemRefType
+        inferCollapsedType(MemRefType type,
+                           SmallVector<ReassociationIndices> reassociation);
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 36e01ef46b30b..efa0453dda036 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
@@ -410,13 +411,8 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
         stripMinedType.cast<RankedTensorType>(),
         packingMetadata.reassociations);
   } else if (stripMinedType.isa<MemRefType>()) {
-    auto memrefTy = stripMinedType.cast<MemRefType>();
-    auto tensorTy =
-        RankedTensorType::get(memrefTy.getShape(), memrefTy.getElementType());
-    auto collapsedTensorType = tensor::CollapseShapeOp::inferCollapsedType(
-        tensorTy, packingMetadata.reassociations);
-    collapsedType = MemRefType::get(collapsedTensorType.getShape(),
-                                    collapsedTensorType.getElementType());
+    collapsedType = memref::CollapseShapeOp::inferCollapsedType(
+        stripMinedType.cast<MemRefType>(), packingMetadata.reassociations);
   }
 
   // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11597505e7888..ba12cc34d6457 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
@@ -1124,7 +1125,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
         }
       } // else dim.getIndex is a block argument to reshape->getBlock and
         // dominates reshape
-    }   // Check condition 2
+    } // Check condition 2
     else if (dim->getBlock() != reshape->getBlock() &&
              !dim.getIndex().getParentRegion()->isProperAncestor(
                  reshape->getParentRegion())) {
@@ -2525,6 +2526,35 @@ MemRefType CollapseShapeOp::computeCollapsedType(
                          srcType.getMemorySpace());
 }
 
+MemRefType
+CollapseShapeOp::inferCollapsedType(MemRefType type,
+                                    ArrayRef<AffineMap> reassociation) {
+  auto shape = type.getShape();
+  SmallVector<int64_t, 4> newShape;
+  assert(isReassociationValid(reassociation) && "invalid reassociation");
+  unsigned currentDim = 0;
+  for (AffineMap m : reassociation) {
+    unsigned dim = m.getNumResults();
+    auto band = shape.slice(currentDim, dim);
+    int64_t size = 1;
+    if (llvm::is_contained(band, ShapedType::kDynamic))
+      size = ShapedType::kDynamic;
+    else
+      for (unsigned d = 0; d < dim; ++d)
+        size *= shape[currentDim + d];
+    newShape.push_back(size);
+    currentDim += dim;
+  }
+  return MemRefType::get(newShape, type.getElementType());
+}
+
+MemRefType CollapseShapeOp::inferCollapsedType(
+    MemRefType type, SmallVector<ReassociationIndices> reassociation) {
+  return inferCollapsedType(
+      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+                type.getContext(), reassociation)));
+}
+
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                             ArrayRef<ReassociationIndices> reassociation,
                             ArrayRef<NamedAttribute> attrs) {

>From c5b3c3955321ef0e9211226c8fea017bd4b591bf Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 1 Mar 2025 09:39:30 +0900
Subject: [PATCH 09/33] add test

---
 .../lib/Dialect/Linalg/Transforms/Transforms.cpp |  5 ++---
 mlir/test/Dialect/Linalg/loops.mlir              | 16 ++++++++++++++++
 2 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index efa0453dda036..98dab332b2f40 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -408,11 +408,10 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
   ShapedType collapsedType;
   if (stripMinedType.isa<TensorType>()) {
     collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
-        stripMinedType.cast<RankedTensorType>(),
-        packingMetadata.reassociations);
+        cast<RankedTensorType>(stripMinedType), packingMetadata.reassociations);
   } else if (stripMinedType.isa<MemRefType>()) {
     collapsedType = memref::CollapseShapeOp::inferCollapsedType(
-        stripMinedType.cast<MemRefType>(), packingMetadata.reassociations);
+        cast<MemRefType>(stripMinedType), packingMetadata.reassociations);
   }
 
   // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index efe8010cffc91..767f593329f52 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -942,3 +942,19 @@ func.func @transpose(%input: memref<?xf32>,
 //      CHECKPARALLEL:      }
 //      CHECKPARALLEL:      return
 //      CHECKPARALLEL:    }
+
+// Test that we can lower all the way to LLVM without crashing, don't check results here.
+func.func @pack_memref(%source: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
+  %dest = memref.alloc() : memref<8x16x8x32xf32>
+  %packed = linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+      into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
+  return %packed : memref<8x16x8x32xf32>
+}
+
+// Test that we can lower all the way to LLVM without crashing, don't check results here.
+func.func @unpack_memref(%source: memref<16x8x8x32xf32>) -> memref<128x256xf32> {
+  %dest = memref.alloc() : memref<128x256xf32>
+  %unpacked = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+      into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
+  return %unpacked : memref<128x256xf32>
+}
\ No newline at end of file

>From a5d01dffda768947463451af6cab1cf6e282114e Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Sun, 16 Mar 2025 21:21:41 +0900
Subject: [PATCH 10/33] fix upon review

---
 .../Dialect/Linalg/IR/RelayoutOpInterface.td  |  1 -
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  7 +--
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 14 +++--
 .../Transforms/PackAndUnpackPatterns.cpp      | 24 +++++---
 .../Dialect/Linalg/Transforms/Transforms.cpp  |  2 +-
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 56 +++++++++----------
 mlir/test/Dialect/Linalg/loops.mlir           | 16 ------
 mlir/test/Dialect/Linalg/roundtrip.mlir       | 18 ++++++
 8 files changed, 71 insertions(+), 67 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
index 467d862d277eb..2dec2fc4396f4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
@@ -10,7 +10,6 @@
 #define LINALG_IR_RELAYOUTOPINTERFACE
 
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
-include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/IR/OpBase.td"
 
 def LinalgRelayoutOpInterface : OpInterface<"RelayoutOpInterface"> {
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 87564066d309d..93449766aca4e 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1782,11 +1782,6 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
 
     static MemRefType computeCollapsedType(
         MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
-    static MemRefType
-        inferCollapsedType(MemRefType type, ArrayRef<AffineMap> reassociation);
-    static MemRefType
-        inferCollapsedType(MemRefType type,
-                           SmallVector<ReassociationIndices> reassociation);
   }];
 
   let hasVerifier = 1;
@@ -1806,7 +1801,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
   let summary = "store operation";
   let description = [{
     The `store` op stores an element into a memref at the specified indices.
-    
+
     The number of indices must match the rank of the memref. The indices must
     be in-bounds: `0 <= idx < dim_size`
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a19039fbca67d..b4cbc7c6ad8e9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5001,12 +5001,8 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
 }
 
 bool PackOp::isLikePad() {
-  if (auto packedTensorType =
-          llvm::dyn_cast<RankedTensorType>((*this)->getResultTypes().front()))
-    return isLikePadUnPad(*this, packedTensorType);
-  if (auto packedTensorType =
-          llvm::dyn_cast<MemRefType>((*this)->getResultTypes().front()))
-    return isLikePadUnPad(*this, packedTensorType);
+  auto packedTensorType = llvm::dyn_cast<ShapedType>((*this)->getResultTypes().front());
+  return isLikePadUnPad(*this, packedTensorType);
 }
 
 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
@@ -5042,6 +5038,9 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
     if (!tensor::hasFoldableTensorCastOperand(op))
       return failure();
 
+    if (!op.hasPureTensorSemantics())
+      return failure();
+
     SmallVector<Type> newResultTypes(op->getResultTypes());
     SmallVector<Value> newOperands =
         tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
@@ -5310,6 +5309,9 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
     if (!tensor::hasFoldableTensorCastOperand(op))
       return failure();
 
+    if (!op.hasPureTensorSemantics())
+      return failure();
+
     SmallVector<Type> newResultTypes(op->getResultTypes());
     SmallVector<Value> newOperands =
         tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 599aa3b6668df..59e4b2ff634c2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -171,25 +171,27 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
     return success();
   }
 
-  LogicalResult matchAndRewrite(UnPackOp unpackOp,
+  LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
-    ShapedType destType = unpackOp.getDestType();
-    if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
-        failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
-                          unpackOp.getStaticTiles())) &&
-        !unpackOp.isLikeUnPad()) {
+    if (!unPackOp.hasPureTensorSemantics())
+      return failure();
+    ShapedType destType = unPackOp.getDestType();
+    if (failed(isUnpackOnInnerMostDim(rewriter, unPackOp)) &&
+        failed(isPackOn1D(rewriter, unPackOp, destType.getShape(),
+                          unPackOp.getStaticTiles())) &&
+        !unPackOp.isLikeUnPad()) {
       return failure();
     }
 
-    ShapedType sourceType = unpackOp.getSourceType();
+    ShapedType sourceType = unPackOp.getSourceType();
     auto reassociation =
         getReassociationIndicesForReshape(sourceType, destType);
     if (!reassociation)
       return failure();
     Value collapsed = insertCollapse(
-        rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
+        rewriter, unPackOp.getLoc(), unPackOp.getSource(), destType,
         getReassociationIndicesAttribute(rewriter, *reassociation));
-    rewriter.replaceOp(unpackOp, collapsed);
+    rewriter.replaceOp(unPackOp, collapsed);
     return success();
   }
 };
@@ -426,6 +428,8 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
 
   LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
+    if (!unPackOp.hasPureTensorSemantics())
+      return failure();
     auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
     if (!linalgOp)
       return failure();
@@ -507,6 +511,8 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
 
   LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
+    if (!unPackOp.hasPureTensorSemantics())
+      return failure();
     // Check for tensor.empty source.
     auto emptyOp = unPackOp.getSource().getDefiningOp<tensor::EmptyOp>();
     if (!emptyOp)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 98dab332b2f40..105831a3d9259 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -410,7 +410,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
     collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
         cast<RankedTensorType>(stripMinedType), packingMetadata.reassociations);
   } else if (stripMinedType.isa<MemRefType>()) {
-    collapsedType = memref::CollapseShapeOp::inferCollapsedType(
+    collapsedType = memref::CollapseShapeOp::computeCollapsedType(
         cast<MemRefType>(stripMinedType), packingMetadata.reassociations);
   }
 
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index ba12cc34d6457..03c08756d110b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2526,34 +2526,34 @@ MemRefType CollapseShapeOp::computeCollapsedType(
                          srcType.getMemorySpace());
 }
 
-MemRefType
-CollapseShapeOp::inferCollapsedType(MemRefType type,
-                                    ArrayRef<AffineMap> reassociation) {
-  auto shape = type.getShape();
-  SmallVector<int64_t, 4> newShape;
-  assert(isReassociationValid(reassociation) && "invalid reassociation");
-  unsigned currentDim = 0;
-  for (AffineMap m : reassociation) {
-    unsigned dim = m.getNumResults();
-    auto band = shape.slice(currentDim, dim);
-    int64_t size = 1;
-    if (llvm::is_contained(band, ShapedType::kDynamic))
-      size = ShapedType::kDynamic;
-    else
-      for (unsigned d = 0; d < dim; ++d)
-        size *= shape[currentDim + d];
-    newShape.push_back(size);
-    currentDim += dim;
-  }
-  return MemRefType::get(newShape, type.getElementType());
-}
-
-MemRefType CollapseShapeOp::inferCollapsedType(
-    MemRefType type, SmallVector<ReassociationIndices> reassociation) {
-  return inferCollapsedType(
-      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
-                type.getContext(), reassociation)));
-}
+// MemRefType
+// CollapseShapeOp::inferCollapsedType(MemRefType type,
+//                                     ArrayRef<AffineMap> reassociation) {
+//   auto shape = type.getShape();
+//   SmallVector<int64_t, 4> newShape;
+//   assert(isReassociationValid(reassociation) && "invalid reassociation");
+//   unsigned currentDim = 0;
+//   for (AffineMap m : reassociation) {
+//     unsigned dim = m.getNumResults();
+//     auto band = shape.slice(currentDim, dim);
+//     int64_t size = 1;
+//     if (llvm::is_contained(band, ShapedType::kDynamic))
+//       size = ShapedType::kDynamic;
+//     else
+//       for (unsigned d = 0; d < dim; ++d)
+//         size *= shape[currentDim + d];
+//     newShape.push_back(size);
+//     currentDim += dim;
+//   }
+//   return MemRefType::get(newShape, type.getElementType());
+// }
+
+// MemRefType CollapseShapeOp::inferCollapsedType(
+//     MemRefType type, SmallVector<ReassociationIndices> reassociation) {
+//   return inferCollapsedType(
+//       type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+//                 type.getContext(), reassociation)));
+// }
 
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                             ArrayRef<ReassociationIndices> reassociation,
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 767f593329f52..efe8010cffc91 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -942,19 +942,3 @@ func.func @transpose(%input: memref<?xf32>,
 //      CHECKPARALLEL:      }
 //      CHECKPARALLEL:      return
 //      CHECKPARALLEL:    }
-
-// Test that we can lower all the way to LLVM without crashing, don't check results here.
-func.func @pack_memref(%source: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
-  %dest = memref.alloc() : memref<8x16x8x32xf32>
-  %packed = linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
-      into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
-  return %packed : memref<8x16x8x32xf32>
-}
-
-// Test that we can lower all the way to LLVM without crashing, don't check results here.
-func.func @unpack_memref(%source: memref<16x8x8x32xf32>) -> memref<128x256xf32> {
-  %dest = memref.alloc() : memref<128x256xf32>
-  %unpacked = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
-      into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
-  return %unpacked : memref<128x256xf32>
-}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index dc556761b09e5..7f7aa12534a9b 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -706,3 +706,21 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
 // CHECK-LABEL: func @conv2d_channel_first_q_promote(
 // CHECK:   %[[arg0:[a-zA-z0-9]*]]: tensor<100x3x224x224xi8>, %[[arg1:[a-zA-z0-9]*]]: tensor<64x3x5x5xi8>, %[[arg2:[a-zA-z0-9]*]]: i8, %[[arg3:[a-zA-z0-9]*]]: i8)
 // CHECK:         linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
+
+// -----
+// Test that we can lower all the way to LLVM without crashing, don't check results here.
+func.func @pack_memref(%source: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
+  %dest = memref.alloc() : memref<8x16x8x32xf32>
+  %packed = linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+      into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
+  return %packed : memref<8x16x8x32xf32>
+}
+
+// -----
+// Test that we can lower all the way to LLVM without crashing, don't check results here.
+func.func @unpack_memref(%source: memref<16x8x8x32xf32>) -> memref<128x256xf32> {
+  %dest = memref.alloc() : memref<128x256xf32>
+  %unpacked = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+      into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
+  return %unpacked : memref<128x256xf32>
+}

>From 2480616ebfbb968d83ab119bf7d6a84897f482e5 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Sun, 23 Mar 2025 15:09:40 +0900
Subject: [PATCH 11/33] lint

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 3 ++-
 mlir/test/Dialect/Linalg/roundtrip.mlir  | 8 ++++----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b4cbc7c6ad8e9..8d71cc0142556 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5001,7 +5001,8 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
 }
 
 bool PackOp::isLikePad() {
-  auto packedTensorType = llvm::dyn_cast<ShapedType>((*this)->getResultTypes().front());
+  auto packedTensorType =
+      llvm::dyn_cast<ShapedType>((*this)->getResultTypes().front());
   return isLikePadUnPad(*this, packedTensorType);
 }
 
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 7f7aa12534a9b..c2e9e3fbd5423 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -711,16 +711,16 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
 func.func @pack_memref(%source: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
   %dest = memref.alloc() : memref<8x16x8x32xf32>
-  %packed = linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+  linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
       into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
-  return %packed : memref<8x16x8x32xf32>
+  return %dest : memref<8x16x8x32xf32>
 }
 
 // -----
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
 func.func @unpack_memref(%source: memref<16x8x8x32xf32>) -> memref<128x256xf32> {
   %dest = memref.alloc() : memref<128x256xf32>
-  %unpacked = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+  linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
       into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
-  return %unpacked : memref<128x256xf32>
+  return %dest : memref<128x256xf32>
 }

>From 7b92a4ee2af6c15035dbb5824f23f2524c7aa1a3 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Mon, 24 Mar 2025 10:37:02 +0900
Subject: [PATCH 12/33] format fix

---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    |  1 -
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 29 -------------------
 2 files changed, 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 785c7cc924159..63d36ec1fd3d6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -229,7 +229,6 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     /// 2. pads the other ones, and
     /// 3. doesn't shuffle the dimensions
     bool isLikePad();
-
   }];
 
   let hasCanonicalizeMethod = 1;
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 70d44904788b1..dbd3f6d631a8a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2526,35 +2526,6 @@ MemRefType CollapseShapeOp::computeCollapsedType(
                          srcType.getMemorySpace());
 }
 
-// MemRefType
-// CollapseShapeOp::inferCollapsedType(MemRefType type,
-//                                     ArrayRef<AffineMap> reassociation) {
-//   auto shape = type.getShape();
-//   SmallVector<int64_t, 4> newShape;
-//   assert(isReassociationValid(reassociation) && "invalid reassociation");
-//   unsigned currentDim = 0;
-//   for (AffineMap m : reassociation) {
-//     unsigned dim = m.getNumResults();
-//     auto band = shape.slice(currentDim, dim);
-//     int64_t size = 1;
-//     if (llvm::is_contained(band, ShapedType::kDynamic))
-//       size = ShapedType::kDynamic;
-//     else
-//       for (unsigned d = 0; d < dim; ++d)
-//         size *= shape[currentDim + d];
-//     newShape.push_back(size);
-//     currentDim += dim;
-//   }
-//   return MemRefType::get(newShape, type.getElementType());
-// }
-
-// MemRefType CollapseShapeOp::inferCollapsedType(
-//     MemRefType type, SmallVector<ReassociationIndices> reassociation) {
-//   return inferCollapsedType(
-//       type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
-//                 type.getContext(), reassociation)));
-// }
-
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                             ArrayRef<ReassociationIndices> reassociation,
                             ArrayRef<NamedAttribute> attrs) {

>From 6dc08ae1628ab2c5795f17af1a3b1ff682e5d861 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Tue, 25 Mar 2025 14:58:06 +0900
Subject: [PATCH 13/33] revert changes

---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    |  9 ++++++-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  6 +++++
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 27 +++++--------------
 .../Linalg/Transforms/Vectorization.cpp       |  2 +-
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  3 +--
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp    | 10 +++----
 6 files changed, 28 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 63d36ec1fd3d6..03da3d38ef4c5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -34,7 +34,7 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
       Op<Linalg_Dialect, mnemonic, !listconcat(traits, [
         DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
         DestinationStyleOpInterface, LinalgRelayoutOpInterface,
-        ConditionallySpeculatable, NoMemoryEffect,
+        ConditionallySpeculatable, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
         DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
         TypesMatchWith<"result type matches type of dest",
                    "dest", "result",
@@ -76,6 +76,13 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
     /// have been tiled. Also, the order of the output dimensions is consistent
     /// with `inner_dims_pos` rather than the packed tensor.
     SmallVector<int64_t> getTiledOuterDims();
+
+         void $cppClass::getEffects(
+         SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+             &effects) {
+       getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+     }
+
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9766c6e56fb7c..1515d648bddca 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4822,6 +4822,9 @@ bool areTilesAndTiledDimsAllConstant(OpTy op) {
 }
 
 Speculation::Speculatability PackOp::getSpeculatability() {
+  if (!hasPureTensorSemantics())
+    return Speculation::NotSpeculatable;
+
   if (getPaddingValue())
     return Speculation::Speculatable;
 
@@ -5122,6 +5125,9 @@ LogicalResult UnPackOp::verify() {
 }
 
 Speculation::Speculatability UnPackOp::getSpeculatability() {
+  if (!hasPureTensorSemantics())
+    return Speculation::NotSpeculatable;
+
   // See PackOp::getSpeculatability.
   if (!areTilesAndTiledDimsAllConstant(*this))
     return Speculation::NotSpeculatable;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 105831a3d9259..085d6e44d854d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
@@ -360,7 +359,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
 
-  ShapedType packedTensorType = unPackOp.getSourceType();
+  RankedTensorType packedTensorType = unPackOp.getSourceType();
   int64_t packedRank = packedTensorType.getRank();
 
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -397,22 +396,10 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
   applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
 
   // 3. Transpose packedShape to stripMinedShape.
-  ShapedType stripMinedType;
-  if (auto tensorType = packedTensorType.dyn_cast<TensorType>()) {
-    stripMinedType =
-        RankedTensorType::get(stripMinedShape, tensorType.getElementType());
-  } else if (auto memrefType = packedTensorType.dyn_cast<MemRefType>()) {
-    stripMinedType =
-        MemRefType::get(stripMinedShape, memrefType.getElementType());
-  }
-  ShapedType collapsedType;
-  if (stripMinedType.isa<TensorType>()) {
-    collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
-        cast<RankedTensorType>(stripMinedType), packingMetadata.reassociations);
-  } else if (stripMinedType.isa<MemRefType>()) {
-    collapsedType = memref::CollapseShapeOp::computeCollapsedType(
-        cast<MemRefType>(stripMinedType), packingMetadata.reassociations);
-  }
+  RankedTensorType stripMinedTensorType =
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      stripMinedTensorType, packingMetadata.reassociations);
 
   // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
   // permutation.
@@ -420,7 +407,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
       tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
   applyPermutationToVector(dims, packedToStripMinedShapePerm);
   auto emptyOp = rewriter.create<tensor::EmptyOp>(
-      loc, dims, stripMinedType.getElementType());
+      loc, dims, stripMinedTensorType.getElementType());
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
       loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
 
@@ -1675,4 +1662,4 @@ void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
 
 void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {
   patterns.add<DecomposePadOpPattern>(patterns.getContext());
-}
+}
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index dfb3f0c90595d..2dcd897330d1e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1669,7 +1669,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);
 
-  ShapedType unpackTensorType = unpackOp.getSourceType();
+  RankedTensorType unpackTensorType = unpackOp.getSourceType();
 
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index dbd3f6d631a8a..1a584a387f2a5 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -9,7 +9,6 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
@@ -1125,7 +1124,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
         }
       } // else dim.getIndex is a block argument to reshape->getBlock and
         // dominates reshape
-    } // Check condition 2
+    }  // Check condition 2
     else if (dim->getBlock() != reshape->getBlock() &&
              !dim.getIndex().getParentRegion()->isProperAncestor(
                  reshape->getParentRegion())) {
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 9a2bd3493f6af..cd0cdd378c352 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -315,11 +315,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
     // have proven that these are not sliced. In this case we just take
     // the full extent of each dimension in the reassociation list.
     if (linearizedDimensions[it.index()]) {
-      llvm::append_range(offsetsSizesAndStrides,
-                         llvm::map_range(it.value(), [&](int64_t idx) -> Range {
-                           return {zeroAttr, collapseShapeInputShape[idx],
-                                   oneAttr};
-                         }));
+      llvm::append_range(
+          offsetsSizesAndStrides,
+          llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+            return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
+          }));
       continue;
     }
 

>From cf7be5780250547577c8eca7c0c021f9590516a9 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Tue, 25 Mar 2025 15:03:54 +0900
Subject: [PATCH 14/33] Undo unrelated whitespace changes in Transforms.cpp and MemRefOps.cpp

---
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 2 +-
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp          | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 085d6e44d854d..dcd50cc44f81b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1662,4 +1662,4 @@ void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
 
 void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {
   patterns.add<DecomposePadOpPattern>(patterns.getContext());
-}
\ No newline at end of file
+}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1a584a387f2a5..59434dccc117b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1124,7 +1124,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
         }
       } // else dim.getIndex is a block argument to reshape->getBlock and
         // dominates reshape
-    }  // Check condition 2
+    }   // Check condition 2
     else if (dim->getBlock() != reshape->getBlock() &&
              !dim.getIndex().getParentRegion()->isProperAncestor(
                  reshape->getParentRegion())) {

>From 4e2f00de633fbde83d6cc967c442c75d809f0536 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Tue, 25 Mar 2025 15:45:58 +0900
Subject: [PATCH 15/33] Pass memref destinations as arguments in pack/unpack roundtrip tests

---
 mlir/test/Dialect/Linalg/roundtrip.mlir | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index c2e9e3fbd5423..d8e11d03bedd4 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -709,7 +709,7 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
 
 // -----
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
-func.func @pack_memref(%source: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
+func.func @pack_memref(%source: memref<128x256xf32>, memref<8x16x8x32xf32>) -> memref<8x16x8x32xf32> {
   %dest = memref.alloc() : memref<8x16x8x32xf32>
   linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
       into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
@@ -718,8 +718,7 @@ func.func @pack_memref(%source: memref<128x256xf32>) -> memref<8x16x8x32xf32> {
 
 // -----
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
-func.func @unpack_memref(%source: memref<16x8x8x32xf32>) -> memref<128x256xf32> {
-  %dest = memref.alloc() : memref<128x256xf32>
+func.func @unpack_memref(%source: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) -> memref<128x256xf32> {
   linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
       into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
   return %dest : memref<128x256xf32>

>From ee7a42a0c739bd4c56d0ce82318199ea01874491 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Thu, 27 Mar 2025 14:09:08 +0900
Subject: [PATCH 16/33] Address review feedback: add getEffects for PackOp and UnPackOp

---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    |  7 ---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 54 +++++++++++++++++++
 .../Dialect/Linalg/Transforms/Transforms.cpp  |  3 +-
 .../Linalg/Transforms/Vectorization.cpp       |  3 +-
 mlir/test/Dialect/Linalg/roundtrip.mlir       |  3 +-
 5 files changed, 59 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 03da3d38ef4c5..980e99872b9a6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -76,13 +76,6 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
     /// have been tiled. Also, the order of the output dimensions is consistent
     /// with `inner_dims_pos` rather than the packed tensor.
     SmallVector<int64_t> getTiledOuterDims();
-
-         void $cppClass::getEffects(
-         SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
-             &effects) {
-       getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
-     }
-
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1515d648bddca..93ca2581f2a3d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4803,6 +4803,60 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
                           getPaddingValue(), metadata.outerDimsPerm);
 }
 
+void PackOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  // No memory effects for pure tensor semantics
+  if (hasPureTensorSemantics())
+    return;
+
+  for (OpOperand &opOperand : getOperation()->getOpOperands()) {
+    if (!llvm::isa<MemRefType>(opOperand.get().getType()))
+      continue;
+
+    if (&opOperand == &getSourceMutable()) {
+      effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
+                         /*effectOnFullRegion=*/true,
+                         SideEffects::DefaultResource::get());
+    }
+    else if (&opOperand == &getDestMutable()) {
+      effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
+                         /*effectOnFullRegion=*/true,
+                         SideEffects::DefaultResource::get());
+      effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
+                         /*effectOnFullRegion=*/true,
+                         SideEffects::DefaultResource::get());
+    }
+  }
+}
+
+void UnPackOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  // No memory effects for pure tensor semantics
+  if (hasPureTensorSemantics())
+    return;
+
+  for (OpOperand &opOperand : getOperation()->getOpOperands()) {
+    if (!llvm::isa<MemRefType>(opOperand.get().getType()))
+      continue;
+
+    if (&opOperand == &getSourceMutable()) {
+      effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
+                         /*effectOnFullRegion=*/true,
+                         SideEffects::DefaultResource::get());
+    }
+    else if (&opOperand == &getDestMutable()) {
+      effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
+                         /*effectOnFullRegion=*/true,
+                         SideEffects::DefaultResource::get());
+      effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
+                         /*effectOnFullRegion=*/true,
+                         SideEffects::DefaultResource::get());
+    }
+  }
+}
+
 /// Returns true if the tiles and the tiled dims are constant.
 template <typename OpTy>
 bool areTilesAndTiledDimsAllConstant(OpTy op) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index dcd50cc44f81b..2ae6474cf3a2f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -359,7 +359,8 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
 
-  RankedTensorType packedTensorType = unPackOp.getSourceType();
+  // TODO: support non-ranked tensor types. ShapedType
+  RankedTensorType packedTensorType = dyn_cast<RankedTensorType>(unPackOp.getSourceType());
   int64_t packedRank = packedTensorType.getRank();
 
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2dcd897330d1e..3b91b897bcfd4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1669,7 +1669,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);
 
-  RankedTensorType unpackTensorType = unpackOp.getSourceType();
+  // TODO: support non-ranked tensor types. ShapedType
+  RankedTensorType unpackTensorType = dyn_cast<RankedTensorType>(unpackOp.getSourceType());
 
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index d8e11d03bedd4..7ca20f684583a 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -709,8 +709,7 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
 
 // -----
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
-func.func @pack_memref(%source: memref<128x256xf32>, memref<8x16x8x32xf32>) -> memref<8x16x8x32xf32> {
-  %dest = memref.alloc() : memref<8x16x8x32xf32>
+func.func @pack_memref(%source: memref<128x256xf32>, %dest: memref<8x16x8x32xf32>) -> memref<8x16x8x32xf32> {
   linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
       into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
   return %dest : memref<8x16x8x32xf32>

>From 5b95ee88d4bd1e4304c73383c3c03308598d0ae6 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Thu, 27 Mar 2025 14:15:52 +0900
Subject: [PATCH 17/33] make clang-format happy

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 30 +++++++++----------
 .../Dialect/Linalg/Transforms/Transforms.cpp  |  3 +-
 2 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 93ca2581f2a3d..7587178dd94d2 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4816,16 +4816,15 @@ void PackOp::getEffects(
 
     if (&opOperand == &getSourceMutable()) {
       effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
-                         /*effectOnFullRegion=*/true,
-                         SideEffects::DefaultResource::get());
-    }
-    else if (&opOperand == &getDestMutable()) {
+                           /*effectOnFullRegion=*/true,
+                           SideEffects::DefaultResource::get());
+    } else if (&opOperand == &getDestMutable()) {
       effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
-                         /*effectOnFullRegion=*/true,
-                         SideEffects::DefaultResource::get());
+                           /*effectOnFullRegion=*/true,
+                           SideEffects::DefaultResource::get());
       effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
-                         /*effectOnFullRegion=*/true,
-                         SideEffects::DefaultResource::get());
+                           /*effectOnFullRegion=*/true,
+                           SideEffects::DefaultResource::get());
     }
   }
 }
@@ -4843,16 +4842,15 @@ void UnPackOp::getEffects(
 
     if (&opOperand == &getSourceMutable()) {
       effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
-                         /*effectOnFullRegion=*/true,
-                         SideEffects::DefaultResource::get());
-    }
-    else if (&opOperand == &getDestMutable()) {
+                           /*effectOnFullRegion=*/true,
+                           SideEffects::DefaultResource::get());
+    } else if (&opOperand == &getDestMutable()) {
       effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
-                         /*effectOnFullRegion=*/true,
-                         SideEffects::DefaultResource::get());
+                           /*effectOnFullRegion=*/true,
+                           SideEffects::DefaultResource::get());
       effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
-                         /*effectOnFullRegion=*/true,
-                         SideEffects::DefaultResource::get());
+                           /*effectOnFullRegion=*/true,
+                           SideEffects::DefaultResource::get());
     }
   }
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 2ae6474cf3a2f..75afcb1fec332 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -360,7 +360,8 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
   rewriter.setInsertionPoint(unPackOp);
 
   // TODO: support non-ranked tensor types. ShapedType
-  RankedTensorType packedTensorType = dyn_cast<RankedTensorType>(unPackOp.getSourceType());
+  RankedTensorType packedTensorType =
+      dyn_cast<RankedTensorType>(unPackOp.getSourceType());
   int64_t packedRank = packedTensorType.getRank();
 
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);

>From 8b5ac5abd85b35ced34839b955247103341dd9a0 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Thu, 27 Mar 2025 14:21:30 +0900
Subject: [PATCH 18/33] make clang-format happy

---
 mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3b91b897bcfd4..f716ff97f7cf3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1670,7 +1670,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   rewriter.setInsertionPoint(unpackOp);
 
   // TODO: support non-ranked tensor types. ShapedType
-  RankedTensorType unpackTensorType = dyn_cast<RankedTensorType>(unpackOp.getSourceType());
+  RankedTensorType unpackTensorType =
+      dyn_cast<RankedTensorType>(unpackOp.getSourceType());
 
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();

>From c955d2137b454af779dedb12cd933da529140846 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Fri, 28 Mar 2025 07:34:26 +0900
Subject: [PATCH 19/33] Share the getEffects implementation between PackOp and UnPackOp

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 43 +++++++++---------------
 1 file changed, 15 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 7587178dd94d2..63977d7165e36 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4803,22 +4803,23 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
                           getPaddingValue(), metadata.outerDimsPerm);
 }
 
-void PackOp::getEffects(
-    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
-        &effects) {
+template <typename OpTy>
+static void getEffectsImpl(
+    OpTy op, SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+                 &effects) {
   // No memory effects for pure tensor semantics
-  if (hasPureTensorSemantics())
+  if (op.hasPureTensorSemantics())
     return;
 
-  for (OpOperand &opOperand : getOperation()->getOpOperands()) {
+  for (OpOperand &opOperand : op.getOperation()->getOpOperands()) {
     if (!llvm::isa<MemRefType>(opOperand.get().getType()))
       continue;
 
-    if (&opOperand == &getSourceMutable()) {
+    if (&opOperand == &op.getSourceMutable()) {
       effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
                            /*effectOnFullRegion=*/true,
                            SideEffects::DefaultResource::get());
-    } else if (&opOperand == &getDestMutable()) {
+    } else if (&opOperand == &op.getDestMutable()) {
       effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
                            /*effectOnFullRegion=*/true,
                            SideEffects::DefaultResource::get());
@@ -4829,30 +4830,16 @@ void PackOp::getEffects(
   }
 }
 
-void UnPackOp::getEffects(
+void PackOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  // No memory effects for pure tensor semantics
-  if (hasPureTensorSemantics())
-    return;
-
-  for (OpOperand &opOperand : getOperation()->getOpOperands()) {
-    if (!llvm::isa<MemRefType>(opOperand.get().getType()))
-      continue;
+  getEffectsImpl(*this, effects);
+}
 
-    if (&opOperand == &getSourceMutable()) {
-      effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
-                           /*effectOnFullRegion=*/true,
-                           SideEffects::DefaultResource::get());
-    } else if (&opOperand == &getDestMutable()) {
-      effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
-                           /*effectOnFullRegion=*/true,
-                           SideEffects::DefaultResource::get());
-      effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
-                           /*effectOnFullRegion=*/true,
-                           SideEffects::DefaultResource::get());
-    }
-  }
+void UnPackOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  getEffectsImpl(*this, effects);
 }
 
 /// Returns true if the tiles and the tiled dims are constant.

>From 276069d36b4bb88b628d2b29f20f6c85e76aa931 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Sun, 30 Mar 2025 08:47:51 +0900
Subject: [PATCH 20/33] Address review feedback: split inferPackedType and guard tensor-only paths

---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    |   9 +-
 .../mlir/Dialect/Utils/ReshapeOpsUtils.h      |   2 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 101 +++++++++++++-----
 .../Transforms/DataLayoutPropagation.cpp      |   4 +-
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp    |  12 +--
 mlir/test/Dialect/Linalg/roundtrip.mlir       |  10 +-
 6 files changed, 96 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 980e99872b9a6..bd9caa3f6b1a7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -190,7 +190,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     // Method to get the `RankedTensorType` of the result based on the inner
     // tiles, position of the inner tiles (innerDimsPos)  and interchange vector
     // of outer loops (outerDimsPerm).
-    static RankedTensorType inferPackedType(ShapedType sourceType,
+    static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
+        ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
+        ArrayRef<int64_t> outerDimsPerm = {});
+
+    // Method to get the `MemRefType` of the result based on the inner
+    // tiles, position of the inner tiles (innerDimsPos)  and interchange vector
+    // of outer loops (outerDimsPerm).
+    static MemRefType inferPackedMemRefType(MemRefType sourceType,
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
         ArrayRef<int64_t> outerDimsPerm = {});
 
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index a86bf74a7b6a1..99c80a2196567 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -573,7 +573,7 @@ PackingMetadata computePackingMetadata(int64_t packedRank,
 /// Removes the op and replaces the constant with a new constant of the result
 /// shape. When an optional cst attribute is passed, it is reshaped only if the
 /// splat value matches the value in the attribute.
-OpFoldResult reshapeConstantSource(DenseElementsAttr source, ShapedType result,
+OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result,
                                    std::optional<Attribute> cst = std::nullopt);
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index afff911168324..0af14b12da040 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -9,8 +9,8 @@
 // This file implements the Linalg operations.
 //
 //===----------------------------------------------------------------------===//
-
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include <iostream>
 
 #include "mlir/AsmParser/AsmParser.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -29,6 +29,7 @@
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Matchers.h"
@@ -45,6 +46,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Error.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
@@ -4426,15 +4428,30 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
         tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
   };
 
+  // Verify that the source and destination are ranked types.
+  if (!packOrUnPack.getSourceType().hasRank() ||
+      !packOrUnPack.getDestType().hasRank()) {
+    return op->emitError(
+        "expected both source and destination to be shaped types");
+  }
+
   // Verify tiles. Do not allow zero tiles.
   SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
   if (hasZeros(mixedTiles))
     return op->emitError("invalid zero tile factor");
 
+  // Verify that the Operation does not have mixed tensor/buffer semantics.
+  if (!packOrUnPack.hasPureBufferSemantics() &&
+      !packOrUnPack.hasPureTensorSemantics()) {
+    return op->emitError("mixing tensor and buffer semantics is not allowed");
+  }
+  bool hasTensorSemantics = packOrUnPack.hasPureTensorSemantics();
+
   // Verify inner_dims_pos and outer_dims_perm.
   ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                 ? packOrUnPack.getSourceType()
                                 : packOrUnPack.getDestType();
+
   size_t unpackedRank = unpackedType.getRank();
   ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
   ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
@@ -4471,12 +4488,17 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   // Verify result shape is greater than the minimum expected
   // by the pack operation, and that the output shape
   // represents full tiles.
-  RankedTensorType expectedPackedType = PackOp::inferPackedType(
-      unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
-  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
-    return op->emitError("the shape of output is not large enough to hold the "
-                         "packed data. Expected at least ")
-           << expectedPackedType << ", got " << packedType;
+  if (hasTensorSemantics) {
+    RankedTensorType expectedPackedType = PackOp::inferPackedTensorType(
+        cast<RankedTensorType>(unpackedType), packOrUnPack.getStaticTiles(),
+        innerDimsPos, outerDimPerm);
+    if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
+      return op->emitError(
+                 "the shape of output is not large enough to hold the "
+                 "packed data. Expected at least ")
+             << expectedPackedType << ", got " << packedType;
+    }
+  } else {
   }
   if (!llvm::all_of(
           llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
@@ -4680,9 +4702,9 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
   return result;
 }
 
-/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
-/// the packed type. Having a shared helper helps implement these two methods in
-/// a way that ensures that they agree on which dimensions are dynamic.
+/// Helper for PackOp::{getResultShape,inferPackedTensorType}. Returns the shape
+/// of the packed type. Having a shared helper helps implement these two methods
+/// in a way that ensures that they agree on which dimensions are dynamic.
 static SmallVector<int64_t> getPackOpResultTypeShape(
     ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
     ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
@@ -4746,13 +4768,21 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
 
 /// Get the expected packed type based on source type, tile factors, position of
 /// the inner tiles and permutation of the outer tiled loop.
-RankedTensorType PackOp::inferPackedType(ShapedType sourceType,
+RankedTensorType PackOp::inferPackedTensorType(
+    RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
+    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
+  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
+      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
+  return RankedTensorType::get(resultShape, sourceType.getElementType());
+}
+
+MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
                                          ArrayRef<int64_t> innerTileSizes,
                                          ArrayRef<int64_t> innerDimsPos,
                                          ArrayRef<int64_t> outerDimsPerm) {
   SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
       sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
-  return RankedTensorType::get(resultShape, sourceType.getElementType());
+  return MemRefType::get(resultShape, sourceType.getElementType());
 }
 
 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
@@ -4802,7 +4832,7 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
 }
 
 template <typename OpTy>
-static void getEffectsImpl(
+static void getPackUnPackEffectsImpl(
     OpTy op, SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
                  &effects) {
   // No memory effects for pure tensor semantics
@@ -4831,13 +4861,13 @@ static void getEffectsImpl(
 void PackOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getEffectsImpl(*this, effects);
+  getPackUnPackEffectsImpl(*this, effects);
 }
 
 void UnPackOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getEffectsImpl(*this, effects);
+  getPackUnPackEffectsImpl(*this, effects);
 }
 
 /// Returns true if the tiles and the tiled dims are constant.
@@ -4972,35 +5002,49 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     return success();
   }
 
-  // Insert tensor.cast ops if static shape inference is available..
+  // Insert either tensor.cast or memref.cast ops
+  // if static shape inference is available..
+  bool hasTensorSemantics = packOp.hasPureTensorSemantics();
+
   SmallVector<int64_t> srcShape, destShape;
   if (inferStaticShape(packOp, srcShape, destShape)) {
     Location loc = packOp.getLoc();
     Value source = packOp.getSource();
     if (srcShape != packOp.getSourceType().getShape()) {
       auto newSrcType = packOp.getSourceType().clone(srcShape);
-      source =
-          rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
+      if (hasTensorSemantics)
+        source = rewriter.create<tensor::CastOp>(loc, newSrcType,
+                                                 packOp.getSource());
+      else
+        source = rewriter.create<memref::CastOp>(loc, newSrcType,
+                                                 packOp.getSource());
     }
     Value dest = packOp.getDest();
     ShapedType originalResultType = packOp.getDestType();
     bool needUpdateDestType = (destShape != originalResultType.getShape());
     if (needUpdateDestType) {
       auto newDestType = packOp.getDestType().clone(destShape);
-      dest =
-          rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
+      if (hasTensorSemantics)
+        dest =
+            rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
     }
     rewriter.modifyOpInPlace(packOp, [&] {
       packOp.getSourceMutable().assign(source);
       packOp.getDestMutable().assign(dest);
-      packOp.getResult().setType(cast<ShapedType>(dest.getType()));
+      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
     });
     // Insert a cast if needed
     if (needUpdateDestType) {
       rewriter.setInsertionPointAfter(packOp);
-      auto castOp =
-          rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
-      rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+      if (hasTensorSemantics) {
+        auto castOp =
+            rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
+        rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+      } else {
+        auto castOp =
+            rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
+        rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+      }
     }
     return success();
   }
@@ -5047,12 +5091,15 @@ bool PackOp::isLikePad() {
 }
 
 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
+  if (!hasPureTensorSemantics())
+    return {};
+
   std::optional<Attribute> paddingValue;
   if (auto pad = adaptor.getPaddingValue())
     paddingValue = pad;
   if (OpFoldResult reshapedSource = reshapeConstantSource(
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
-          getDestType(), paddingValue))
+          cast<TensorType>(getDestType()), paddingValue))
     return reshapedSource;
   return {};
 }
@@ -5324,9 +5371,11 @@ bool UnPackOp::isLikeUnPad() {
 }
 
 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
+  if (!hasPureTensorSemantics())
+    return {};
   if (OpFoldResult reshapedSource = reshapeConstantSource(
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
-          getResult().getType()))
+          cast<TensorType>(getResult().getType())))
     return reshapedSource;
   return {};
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 9f5000b70b6f6..22bd5a8b38862 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -808,7 +808,7 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
   // If reassociation is not possible, then reordering cannot happen.
   // This can be caused by pack padding affecting previously expanded
   // dimensions or packing extending dimensions.
-  RankedTensorType newPackType = linalg::PackOp::inferPackedType(
+  RankedTensorType newPackType = linalg::PackOp::inferPackedTensorType(
       expandOp.getSrcType(), packOp.getStaticInnerTiles(),
       projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
   auto reassocExpand =
@@ -943,7 +943,7 @@ static LogicalResult pushDownUnPackOpThroughExpandShape(
     nextPos += 1;
   }
 
-  RankedTensorType newExpandType = linalg::PackOp::inferPackedType(
+  RankedTensorType newExpandType = linalg::PackOp::inferPackedTensorType(
       expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
   auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
       expandOp.getLoc(), newExpandType, unPackOp.getSource(),
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index cd0cdd378c352..86a1fb12f2b26 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -315,11 +315,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
     // have proven that these are not sliced. In this case we just take
     // the full extent of each dimension in the reassociation list.
     if (linearizedDimensions[it.index()]) {
-      llvm::append_range(
-          offsetsSizesAndStrides,
-          llvm::map_range(it.value(), [&](int64_t idx) -> Range {
-            return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
-          }));
+      llvm::append_range(offsetsSizesAndStrides,
+                         llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+                           return {zeroAttr, collapseShapeInputShape[idx],
+                                   oneAttr};
+                         }));
       continue;
     }
 
@@ -485,7 +485,7 @@ PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
 }
 
 OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
-                                         ShapedType result,
+                                         TensorType result,
                                          std::optional<Attribute> cst) {
   if (source && source.isSplat() && result.hasStaticShape() &&
       (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 7ca20f684583a..550d717570e69 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -708,17 +708,15 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
 // CHECK:         linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
 
 // -----
-// Test that we can lower all the way to LLVM without crashing, don't check results here.
-func.func @pack_memref(%source: memref<128x256xf32>, %dest: memref<8x16x8x32xf32>) -> memref<8x16x8x32xf32> {
+func.func @pack_memref(%source: memref<128x256xf32>, %dest: memref<8x16x8x32xf32>) {
   linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
       into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
-  return %dest : memref<8x16x8x32xf32>
+  return
 }
 
 // -----
-// Test that we can lower all the way to LLVM without crashing, don't check results here.
-func.func @unpack_memref(%source: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) -> memref<128x256xf32> {
+func.func @unpack_memref(%source: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) {
   linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
       into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
-  return %dest : memref<128x256xf32>
+  return
 }

>From 790e974e544fd8552cc668a621795f661b292247 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Sun, 30 Mar 2025 18:01:35 +0900
Subject: [PATCH 21/33] bail out transforms using PackOp, UnPackOp

---
 .../Linalg/Transforms/BlockPackMatmul.cpp     |  5 ++
 .../Transforms/DataLayoutPropagation.cpp      | 52 +++++++++++++++++++
 .../Linalg/Transforms/Vectorization.cpp       | 25 +++++++++
 3 files changed, 82 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 81842e4bea631..0b3d86d51ca0a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -91,6 +91,11 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                       linalg::PackOp packOp, AffineMap operandMap,
                       ArrayRef<unsigned> blocksStartDimPos,
                       bool transposeOuterBlocks, bool transposeInnerBlocks) {
+  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics()) {
+    return failure();
+  }
+
   assert(operandMap.getNumDims() >= 4 &&
          "expected at least 4D prepacked matmul");
   assert(blocksStartDimPos.size() >= 2 &&
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 22bd5a8b38862..ced3719ff8c3e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -63,6 +63,12 @@ getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                           OpTy packOrUnPackOp) {
   static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
                 "applies to only pack or unpack operations");
+  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+  if (isa<linalg::LinalgOp>(packOrUnPackOp)) {
+    if (!packOrUnPackOp.hasPureTensorSemantics()) {
+      return failure();
+    }
+  }
   LLVM_DEBUG(
       { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
 
@@ -373,6 +379,11 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
 static FailureOr<GenericOp>
 bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
                                const ControlPropagationFn &controlFn) {
+  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics()) {
+    return failure();
+  }
+
   auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
   if (!genericOp)
     return failure();
@@ -461,6 +472,11 @@ struct BubbleUpPackOpThroughGenericOpPattern
 
   LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                 PatternRewriter &rewriter) const override {
+    // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+    if (!packOp.hasPureTensorSemantics()) {
+      return failure();
+    }
+
     auto genericOp =
         bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
     if (failed(genericOp))
@@ -483,6 +499,11 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
 
   LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                 PatternRewriter &rewriter) const override {
+    // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+    if (!packOp.hasPureTensorSemantics()) {
+      return failure();
+    }
+
     auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
     if (!padOp)
       return failure();
@@ -651,6 +672,11 @@ static LogicalResult
 bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                    linalg::PackOp packOp,
                                    PatternRewriter &rewriter) {
+  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics()) {
+    return failure();
+  }
+
   SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
   ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
   ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
@@ -757,6 +783,11 @@ static LogicalResult
 bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
                                  linalg::PackOp packOp,
                                  PatternRewriter &rewriter) {
+  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics()) {
+    return failure();
+  }
+
   // Outer dimensions permutation is not supported currently.
   // TODO: Handle outer_dims_perm variants.
   ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
@@ -840,6 +871,11 @@ class BubbleUpPackOpThroughReshapeOp final
 
   LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                 PatternRewriter &rewriter) const override {
+    // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+    if (!packOp.hasPureTensorSemantics()) {
+      return failure();
+    }
+
     Operation *srcOp = packOp.getSource().getDefiningOp();
     // Currently only support when the pack op is the only user.
     if (!srcOp || !(srcOp->getNumResults() == 1) ||
@@ -893,6 +929,11 @@ class BubbleUpPackOpThroughReshapeOp final
 static LogicalResult pushDownUnPackOpThroughExpandShape(
     linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
     PatternRewriter &rewriter, ControlPropagationFn controlFn) {
+  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+  if (!unPackOp.hasPureTensorSemantics()) {
+    return failure();
+  }
+
   // User controlled propagation function.
   if (!controlFn(&expandOp.getSrcMutable()))
     return failure();
@@ -970,6 +1011,11 @@ class PushDownUnPackOpThroughReshapeOp final
 
   LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
+    // TODO(issues/129004): Support MemRef UnPackOp. Temporarily return failure.
+    if (!unPackOp.hasPureTensorSemantics()) {
+      return failure();
+    }
+
     Value result = unPackOp.getResult();
     // Currently only support unpack op with the single user.
     if (!result.hasOneUse()) {
@@ -1146,11 +1192,17 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
 
   LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                 PatternRewriter &rewriter) const override {
+
     linalg::UnPackOp unpackOp =
         padOp.getSource().getDefiningOp<linalg::UnPackOp>();
+
     if (!unpackOp)
       return failure();
 
+    // TODO(issues/129004): Support MemRef PadOp. Temporarily return failure.
+    if (!unpackOp.hasPureTensorSemantics())
+      return failure();
+
     if (!controlFn(&padOp.getSourceMutable()))
       return failure();
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f716ff97f7cf3..aba729ec3f5cd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1588,6 +1588,11 @@ static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
+  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics()) {
+    return failure();
+  }
+
   // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(packOp);
@@ -1664,6 +1669,10 @@ static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
                           SmallVectorImpl<Value> &newResults) {
+  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+  if (!unpackOp.hasPureTensorSemantics()) {
+    return failure();
+  }
 
   // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
@@ -1891,6 +1900,10 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
 static LogicalResult
 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                               ArrayRef<int64_t> inputVectorSizes) {
+  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+  if (!unpackOp.hasPureTensorSemantics()) {
+    return failure();
+  }
 
   if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
         return !getConstantIntValue(res).has_value();
@@ -2136,6 +2149,11 @@ static LogicalResult vectorizeLinalgOpPrecondition(
 static LogicalResult
 vectorizePackOpPrecondition(linalg::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
+  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics()) {
+    return failure();
+  }
+
   auto padValue = packOp.getPaddingValue();
   Attribute cstAttr;
   if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
@@ -2358,6 +2376,13 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
 }
 
 bool mlir::linalg::hasVectorizationImpl(Operation *op) {
+  // TODO(issues/129004): Support MemRef PackOp. Temporarily return false.
+  // Actually do we need this?
+  if (isa<linalg::PackOp, linalg::UnPackOp>(op)) {
+    if (!cast<LinalgOp>(op).hasPureTensorSemantics()) {
+      return false;
+    }
+  }
   return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
              tensor::InsertSliceOp>(op);
 }

>From 820e40b994b9b26b92c7f184b2b9a01c1328d489 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Sun, 30 Mar 2025 19:23:21 +0900
Subject: [PATCH 22/33] fix build error

---
 mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index ced3719ff8c3e..199011ac901ce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -64,8 +64,9 @@ getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
   static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
                 "applies to only pack or unpack operations");
   // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
-  if (isa<linalg::LinalgOp>(packOrUnPackOp)) {
-    if (!packOrUnPackOp.hasPureTensorSemantics()) {
+  if (auto linalgOp =
+          dyn_cast<linalg::LinalgOp>(packOrUnPackOp.getOperation())) {
+    if (!linalgOp.hasPureTensorSemantics()) {
       return failure();
     }
   }

>From 43a64b912adaa2eed85d9715c13c3057c2c4b53e Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Sun, 30 Mar 2025 20:38:34 +0900
Subject: [PATCH 23/33] bail out on memref pack/unpack in Transforms.cpp; drop hasVectorizationImpl check

---
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 25 +++++++++++++++++++
 .../Linalg/Transforms/Vectorization.cpp       |  7 ------
 2 files changed, 25 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 75afcb1fec332..63c0e4d126c9a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -219,6 +219,11 @@ struct PackedOperandsDimList {
 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                              linalg::PackOp packOp,
                                              bool lowerPadLikeWithInsertSlice) {
+  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics()) {
+    return failure();
+  }
+
   // 1. Filter out NYI cases.
   auto packedTensorType =
       cast<RankedTensorType>(packOp->getResultTypes().front());
@@ -355,6 +360,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
 FailureOr<LowerUnPackOpResult>
 linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
                     bool lowerUnpadLikeWithExtractSlice) {
+  // TODO(issues/129004): Support MemRef UnPackOp. Temporarily return failure.
+  if (!unPackOp.hasPureTensorSemantics()) {
+    return failure();
+  }
+
   Location loc = unPackOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
@@ -1032,6 +1042,11 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
     return input;
   }
 
+  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics()) {
+    return packOp.getSource();
+  }
+
   assert(llvm::all_of(packOp.getAllOuterDims(),
                       [](int64_t val) { return val == 1; }) &&
          "some outer dims are != 1");
@@ -1144,6 +1159,11 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
 
 LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
     linalg::PackOp packOp, PatternRewriter &rewriter) const {
+  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics()) {
+    return failure();
+  }
+
   // TODO: support the case that outer dimensions are not all 1s. A
   // tensor.expand_shape will be generated in this case.
   if (llvm::any_of(packOp.getAllOuterDims(),
@@ -1245,6 +1265,11 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
 
 LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
     linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
+  // TODO(issues/129004): Support MemRef UnPackOp. Temporarily return failure.
+  if (!unpackOp.hasPureTensorSemantics()) {
+    return failure();
+  }
+
   int64_t srcRank = unpackOp.getSourceRank();
   int64_t destRank = unpackOp.getDestRank();
   ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index aba729ec3f5cd..8936f9d9e389e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2376,13 +2376,6 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
 }
 
 bool mlir::linalg::hasVectorizationImpl(Operation *op) {
-  // TODO(issues/129004): Support MemRef PackOp. Temporarily return false.
-  // Actually do we need this?
-  if (isa<linalg::PackOp, linalg::UnPackOp>(op)) {
-    if (!cast<LinalgOp>(op).hasPureTensorSemantics()) {
-      return false;
-    }
-  }
   return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
              tensor::InsertSliceOp>(op);
 }

>From 486c62b7e91efca21f0aff37949095cac10b7895 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Wed, 2 Apr 2025 15:33:59 +0900
Subject: [PATCH 24/33] add invalid pack/unpack cases

---
 mlir/test/Dialect/Linalg/invalid.mlir | 37 +++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 90ceadebbc1fa..aa12778ffbf7f 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1666,3 +1666,40 @@ func.func @unpack_static_inner_tile_size_and_dynamic_output_shape(
   %0 = linalg.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x?x4xf32> -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
+// -----
+
+func.func @pack_source_dest_type_mismatch_1(%source: tensor<128x256xf32>, %dest: memref<8x16x8x32xf32>) {
+  // expected-error at +1 {{mixing tensor and buffer semantics is not allowed}}
+  linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+      into %dest : tensor<128x256xf32> -> memref<8x16x8x32xf32>
+  return
+}
+
+
+// -----
+
+func.func @pack_source_dest_type_mismatch_2(%source: memref<128x256xf32>, %dest: tensor<8x16x8x32xf32>) {
+  // expected-error at +1 {{mixing tensor and buffer semantics is not allowed}}
+  %0 = linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+      into %dest : memref<128x256xf32> -> tensor<8x16x8x32xf32>
+  return
+}
+
+// -----
+
+func.func @unpack_source_dest_type_mismatch_1(%source: tensor<16x8x8x32xf32>, %dest: memref<128x256xf32>) {
+  // expected-error at +1 {{mixing tensor and buffer semantics is not allowed}}
+  linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+      into %dest : tensor<16x8x8x32xf32> -> memref<128x256xf32>
+  return
+}
+
+// -----
+
+func.func @unpack_source_dest_type_mismatch_2(%source: memref<16x8x8x32xf32>, %dest: tensor<128x256xf32>) {
+  // expected-error at +1 {{mixing tensor and buffer semantics is not allowed}}
+  %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+      into %dest : memref<16x8x8x32xf32> -> tensor<128x256xf32>
+  return
+}

>From ca889b5727d5808902e360fec2ee2c0586b2a879 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Wed, 2 Apr 2025 16:41:06 +0900
Subject: [PATCH 25/33] fix roundtrip test

---
 mlir/test/Dialect/Linalg/roundtrip.mlir | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 550d717570e69..9c5141f56d575 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -708,15 +708,27 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
 // CHECK:         linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
 
 // -----
+
 func.func @pack_memref(%source: memref<128x256xf32>, %dest: memref<8x16x8x32xf32>) {
   linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
       into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
   return
 }
 
+// CHECK-LABEL: func @pack_memref(
+// CHECK-SAME:    %[[SOURCE:[a-zA-Z0-9_]+]]: memref<128x256xf32>, %[[DEST:[a-zA-Z0-9_]+]]: memref<8x16x8x32xf32>) {
+// CHECK:         linalg.pack %[[SOURCE]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DEST]] : memref<128x256xf32> -> memref<8x16x8x32xf32>
+// CHECK:         return
+// CHECK:       }
 // -----
+
 func.func @unpack_memref(%source: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) {
   linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
       into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
   return
 }
+
+// CHECK-LABEL: func @unpack_memref(
+// CHECK-SAME:    %[[SOURCE:[a-zA-Z0-9_]+]]: memref<16x8x8x32xf32>, %[[DEST:[a-zA-Z0-9_]+]]: memref<128x256xf32>) {
+// CHECK:         linalg.unpack %[[SOURCE]] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DEST]] : memref<16x8x8x32xf32> -> memref<128x256xf32>
+// CHECK:         return

>From ce910b9c8158a4b752394801b531ea46458a3e3c Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Wed, 2 Apr 2025 17:11:17 +0900
Subject: [PATCH 26/33] fix upon review

---
 .../Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp  |  2 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 32 +++++++------
 .../Linalg/Transforms/BlockPackMatmul.cpp     |  5 +-
 .../Transforms/DataLayoutPropagation.cpp      | 46 +++++++------------
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 26 +++++------
 .../Linalg/Transforms/Vectorization.cpp       | 23 ++++------
 6 files changed, 60 insertions(+), 74 deletions(-)

diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 6119097456d1f..bb2b474814824 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -78,7 +78,7 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
                                     omp::FlushOp, omp::MapBoundsOp,
                                     omp::ThreadprivateOp>::value) {
         if (isa<MemRefType>(originalOperand.getType())) {
-          // TODO: Support memref type in variable operands
+          // TODO: Support memref type in variable operands
           return rewriter.notifyMatchFailure(op, "memref is not supported yet");
         }
       }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 7e3a714b95bc8..711b48abcc0f4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4431,8 +4431,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   // Verify that the source and destination are ranked types.
   if (!packOrUnPack.getSourceType().hasRank() ||
       !packOrUnPack.getDestType().hasRank()) {
-    return op->emitError(
-        "expected both source and destination to be shaped types");
+    return op->emitError("expected both source and destination to have rank");
   }
 
   // Verify tiles. Do not allow zero tiles.
@@ -5002,31 +5001,26 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     return success();
   }
 
-  // Insert either tensor.cast or memref.cast ops
-  // if static shape inference is available..
+  // Insert tensor.cast ops if static shape inference is available.
   bool hasTensorSemantics = packOp.hasPureTensorSemantics();
 
+  // TODO: support memref.cast; the source/dest casts below are tensor-only.
   SmallVector<int64_t> srcShape, destShape;
   if (inferStaticShape(packOp, srcShape, destShape)) {
     Location loc = packOp.getLoc();
     Value source = packOp.getSource();
     if (srcShape != packOp.getSourceType().getShape()) {
       auto newSrcType = packOp.getSourceType().clone(srcShape);
-      if (hasTensorSemantics)
-        source = rewriter.create<tensor::CastOp>(loc, newSrcType,
-                                                 packOp.getSource());
-      else
-        source = rewriter.create<memref::CastOp>(loc, newSrcType,
-                                                 packOp.getSource());
+      source =
+          rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
     }
     Value dest = packOp.getDest();
     ShapedType originalResultType = packOp.getDestType();
     bool needUpdateDestType = (destShape != originalResultType.getShape());
     if (needUpdateDestType) {
       auto newDestType = packOp.getDestType().clone(destShape);
-      if (hasTensorSemantics)
-        dest =
-            rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
+      dest =
+          rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
     }
     rewriter.modifyOpInPlace(packOp, [&] {
       packOp.getSourceMutable().assign(source);
@@ -5036,6 +5030,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     // Insert a cast if needed
     if (needUpdateDestType) {
       rewriter.setInsertionPointAfter(packOp);
+      // FIXME(review): duplicated below; both cast blocks run. Keep only one.
       if (hasTensorSemantics) {
         auto castOp =
             rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
@@ -5045,6 +5040,16 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
             rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
         rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
       }
+      // FIXME(review): second copy of the cast logic above; keep only one.
+      Operation *castOp;
+      if (hasTensorSemantics) {
+        castOp =
+            rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
+      } else {
+        castOp =
+            rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
+      }
+      rewriter.replaceAllUsesExcept(packOp, castOp->getResult(0), castOp);
     }
     return success();
   }
@@ -5126,6 +5131,7 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
     if (!tensor::hasFoldableTensorCastOperand(op))
       return failure();
 
+    // TODO: Support Memref PackOp. Temporarily return failure.
     if (!op.hasPureTensorSemantics())
       return failure();
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 0b3d86d51ca0a..cdd9d3da9bcf8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -91,10 +91,9 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                       linalg::PackOp packOp, AffineMap operandMap,
                       ArrayRef<unsigned> blocksStartDimPos,
                       bool transposeOuterBlocks, bool transposeInnerBlocks) {
-  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
-  if (!packOp.hasPureTensorSemantics()) {
+  // TODO: Support Memref PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics())
     return failure();
-  }
 
   assert(operandMap.getNumDims() >= 4 &&
          "expected at least 4D prepacked matmul");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 199011ac901ce..54a11ad7c0b02 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -63,13 +63,16 @@ getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                           OpTy packOrUnPackOp) {
   static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
                 "applies to only pack or unpack operations");
-  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
-  if (auto linalgOp =
-          dyn_cast<linalg::LinalgOp>(packOrUnPackOp.getOperation())) {
-    if (!linalgOp.hasPureTensorSemantics()) {
+  if (PackOp packOp = dyn_cast<PackOp>(packOrUnPackOp)) {
+    if (!packOp.hasPureTensorSemantics())
       return failure();
-    }
   }
+
+  if (UnPackOp unpackOp = dyn_cast<UnPackOp>(packOrUnPackOp)) {
+    if (!unpackOp.hasPureTensorSemantics())
+      return failure();
+  }
+
   LLVM_DEBUG(
       { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
 
@@ -380,10 +383,8 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
 static FailureOr<GenericOp>
 bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
                                const ControlPropagationFn &controlFn) {
-  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
-  if (!packOp.hasPureTensorSemantics()) {
+  if (!packOp.hasPureTensorSemantics())
     return failure();
-  }
 
   auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
   if (!genericOp)
@@ -473,10 +474,8 @@ struct BubbleUpPackOpThroughGenericOpPattern
 
   LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                 PatternRewriter &rewriter) const override {
-    // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
-    if (!packOp.hasPureTensorSemantics()) {
+    if (!packOp.hasPureTensorSemantics())
       return failure();
-    }
 
     auto genericOp =
         bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
@@ -500,10 +499,8 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
 
   LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                 PatternRewriter &rewriter) const override {
-    // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
-    if (!packOp.hasPureTensorSemantics()) {
+    if (!packOp.hasPureTensorSemantics())
       return failure();
-    }
 
     auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
     if (!padOp)
@@ -673,10 +670,8 @@ static LogicalResult
 bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                    linalg::PackOp packOp,
                                    PatternRewriter &rewriter) {
-  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
-  if (!packOp.hasPureTensorSemantics()) {
+  if (!packOp.hasPureTensorSemantics())
     return failure();
-  }
 
   SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
   ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
@@ -784,10 +779,8 @@ static LogicalResult
 bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
                                  linalg::PackOp packOp,
                                  PatternRewriter &rewriter) {
-  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
-  if (!packOp.hasPureTensorSemantics()) {
+  if (!packOp.hasPureTensorSemantics())
     return failure();
-  }
 
   // Outer dimensions permutation is not supported currently.
   // TODO: Handle outer_dims_perm variants.
@@ -872,10 +865,8 @@ class BubbleUpPackOpThroughReshapeOp final
 
   LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                 PatternRewriter &rewriter) const override {
-    // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
-    if (!packOp.hasPureTensorSemantics()) {
+    if (!packOp.hasPureTensorSemantics())
       return failure();
-    }
 
     Operation *srcOp = packOp.getSource().getDefiningOp();
     // Currently only support when the pack op is the only user.
@@ -930,10 +921,8 @@ class BubbleUpPackOpThroughReshapeOp final
 static LogicalResult pushDownUnPackOpThroughExpandShape(
     linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
     PatternRewriter &rewriter, ControlPropagationFn controlFn) {
-  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
-  if (!unPackOp.hasPureTensorSemantics()) {
+  if (!unPackOp.hasPureTensorSemantics())
     return failure();
-  }
 
   // User controlled propagation function.
   if (!controlFn(&expandOp.getSrcMutable()))
@@ -1012,10 +1001,8 @@ class PushDownUnPackOpThroughReshapeOp final
 
   LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
-    // TODO(issues/129004): Support MemRef UnPackOp. Temporarily return failure.
-    if (!unPackOp.hasPureTensorSemantics()) {
+    if (!unPackOp.hasPureTensorSemantics())
       return failure();
-    }
 
     Value result = unPackOp.getResult();
     // Currently only support unpack op with the single user.
@@ -1200,7 +1187,6 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
     if (!unpackOp)
       return failure();
 
-    // TODO(issues/129004): Support MemRef PadOp. Temporarily return failure.
     if (!unpackOp.hasPureTensorSemantics())
       return failure();
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 63c0e4d126c9a..49a2dbed14e75 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -219,10 +219,9 @@ struct PackedOperandsDimList {
 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                              linalg::PackOp packOp,
                                              bool lowerPadLikeWithInsertSlice) {
-  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
-  if (!packOp.hasPureTensorSemantics()) {
+  // TODO: Support Memref PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics())
     return failure();
-  }
 
   // 1. Filter out NYI cases.
   auto packedTensorType =
@@ -360,7 +359,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
 FailureOr<LowerUnPackOpResult>
 linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
                     bool lowerUnpadLikeWithExtractSlice) {
-  // TODO(issues/129004): Support MemRef UnPackOp. Temporarily return failure.
+  // TODO: Support Memref PackOp. Temporarily return failure.
   if (!unPackOp.hasPureTensorSemantics()) {
     return failure();
   }
@@ -369,9 +368,10 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
 
-  // TODO: support non-ranked tensor types. ShapedType
-  RankedTensorType packedTensorType =
-      dyn_cast<RankedTensorType>(unPackOp.getSourceType());
+  auto packedTensorType = dyn_cast<RankedTensorType>(unPackOp.getSourceType());
+  if (!packedTensorType)
+    return failure();
+
   int64_t packedRank = packedTensorType.getRank();
 
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -1042,10 +1042,9 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
     return input;
   }
 
-  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
-  if (!packOp.hasPureTensorSemantics()) {
+  // TODO: Support Memref PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics())
     return packOp.getSource();
-  }
 
   assert(llvm::all_of(packOp.getAllOuterDims(),
                       [](int64_t val) { return val == 1; }) &&
@@ -1159,10 +1158,9 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
 
 LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
     linalg::PackOp packOp, PatternRewriter &rewriter) const {
-  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
-  if (!packOp.hasPureTensorSemantics()) {
+  // TODO: Support Memref PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics())
     return failure();
-  }
 
   // TODO: support the case that outer dimensions are not all 1s. A
   // tensor.expand_shape will be generated in this case.
@@ -1265,7 +1263,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
 
 LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
     linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
-  // TODO(issues/129004): Support MemRef UnPackOp. Temporarily return failure.
+  // TODO: Support Memref PackOp. Temporarily return failure.
   if (!unpackOp.hasPureTensorSemantics()) {
     return failure();
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 8936f9d9e389e..c3d2de697efb4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1588,10 +1588,9 @@ static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
-  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
-  if (!packOp.hasPureTensorSemantics()) {
+  // TODO: Support Memref PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics())
     return failure();
-  }
 
   // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
@@ -1669,18 +1668,17 @@ static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
                           SmallVectorImpl<Value> &newResults) {
-  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
-  if (!unpackOp.hasPureTensorSemantics()) {
+  // TODO: Support Memref PackOp. Temporarily return failure.
+  if (!unpackOp.hasPureTensorSemantics())
     return failure();
-  }
 
   // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);
 
-  // TODO: support non-ranked tensor types. ShapedType
-  RankedTensorType unpackTensorType =
-      dyn_cast<RankedTensorType>(unpackOp.getSourceType());
+  auto unpackTensorType = dyn_cast<RankedTensorType>(unpackOp.getSourceType());
+  if (!unpackTensorType)
+    return failure();
 
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
@@ -1900,7 +1898,7 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
 static LogicalResult
 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                               ArrayRef<int64_t> inputVectorSizes) {
-  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
+  // TODO: Support Memref PackOp. Temporarily return failure.
   if (!unpackOp.hasPureTensorSemantics()) {
     return failure();
   }
@@ -2149,10 +2147,9 @@ static LogicalResult vectorizeLinalgOpPrecondition(
 static LogicalResult
 vectorizePackOpPrecondition(linalg::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
-  // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
-  if (!packOp.hasPureTensorSemantics()) {
+  // TODO: Support Memref PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics())
     return failure();
-  }
 
   auto padValue = packOp.getPaddingValue();
   Attribute cstAttr;

>From 6a501bdffa959eb1bf95a09c23b4699798879c9f Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Wed, 2 Apr 2025 18:49:02 +0900
Subject: [PATCH 27/33] fix upon review

---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    |  7 ++++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 37 ++++++++-----------
 .../Transforms/DataLayoutPropagation.cpp      |  9 ++---
 3 files changed, 26 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index bd9caa3f6b1a7..b224b402ff8d6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -201,6 +201,13 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
         ArrayRef<int64_t> outerDimsPerm = {});
 
+    // Method to get the Shape of the result based on the input shape, inner
+    // tiles, position of the inner tiles (innerDimsPos)  and interchange vector
+    // of outer loops (outerDimsPerm).
+    static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
+        ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
+        ArrayRef<int64_t> outerDimsPerm = {});
+
     // Returns true if we have enough static information to catch undefined
     // behavior when the tile size does not divide perfectly the dimension of
     // the input tensor. Detecting UB requires that the input size and either
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 711b48abcc0f4..f285a5093a80b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4487,17 +4487,13 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   // Verify result shape is greater than the minimum expected
   // by the pack operation, and that the output shape
   // represents full tiles.
-  if (hasTensorSemantics) {
-    RankedTensorType expectedPackedType = PackOp::inferPackedTensorType(
-        cast<RankedTensorType>(unpackedType), packOrUnPack.getStaticTiles(),
-        innerDimsPos, outerDimPerm);
-    if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
-      return op->emitError(
-                 "the shape of output is not large enough to hold the "
-                 "packed data. Expected at least ")
-             << expectedPackedType << ", got " << packedType;
-    }
-  } else {
+  auto expectedPackedShape = PackOp::inferPackedShape(
+      unpackedType.getShape(), packOrUnPack.getStaticTiles(),
+      packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
+  if (!areAllInBound(expectedPackedShape, packedType.getShape())) {
+    return op->emitError("the shape of output is not large enough to hold the "
+                         "packed data. Expected at least ")
+           << expectedPackedShape << ", got " << packedType.getShape();
   }
   if (!llvm::all_of(
           llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
@@ -4784,6 +4780,14 @@ MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
   return MemRefType::get(resultShape, sourceType.getElementType());
 }
 
+SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
+                                              ArrayRef<int64_t> innerTileSizes,
+                                              ArrayRef<int64_t> innerDimsPos,
+                                              ArrayRef<int64_t> outerDimsPerm) {
+  return getPackOpResultTypeShape(inputShape, innerTileSizes, innerDimsPos,
+                                  outerDimsPerm);
+}
+
 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
                                       ArrayRef<OpFoldResult> innerTileSizes,
                                       ArrayRef<int64_t> innerDimsPos,
@@ -5030,7 +5034,6 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     // Insert a cast if needed
     if (needUpdateDestType) {
       rewriter.setInsertionPointAfter(packOp);
-      /// 1
       if (hasTensorSemantics) {
         auto castOp =
             rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
@@ -5040,16 +5043,6 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
             rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
         rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
       }
-      /// 2
-      Operation *castOp;
-      if (hasTensorSemantics) {
-        castOp =
-            rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
-      } else {
-        castOp =
-            rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
-      }
-      rewriter.replaceAllUsesExcept(packOp, castOp->getResult(0), castOp);
     }
     return success();
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 54a11ad7c0b02..7891067323165 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -63,13 +63,12 @@ getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                           OpTy packOrUnPackOp) {
   static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
                 "applies to only pack or unpack operations");
-  if (PackOp packOp = dyn_cast<PackOp>(packOrUnPackOp)) {
-    if (!packOp.hasPureTensorSemantics())
+  if constexpr (std::is_same_v<OpTy, linalg::PackOp>) {
+    if (!packOrUnPackOp.hasPureTensorSemantics())
       return failure();
   }
-
-  if (UnPackOp unpackOp = dyn_cast<UnPackOp>(packOrUnPackOp)) {
-    if (!unpackOp.hasPureTensorSemantics())
+  if constexpr (std::is_same_v<OpTy, linalg::UnPackOp>) {
+    if (!packOrUnPackOp.hasPureTensorSemantics())
       return failure();
   }
 

>From 535e796e458cd8c3f82e43177beb7ef1f507d918 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <hyunsungl at nvidia.com>
Date: Wed, 2 Apr 2025 20:04:27 +0900
Subject: [PATCH 28/33] Drop unused hasTensorSemantics and merge pack/unpack
 tensor-semantics checks

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp               |  1 -
 .../Linalg/Transforms/DataLayoutPropagation.cpp        | 10 ++--------
 2 files changed, 2 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f285a5093a80b..ea7ea0694e7a3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4444,7 +4444,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
       !packOrUnPack.hasPureTensorSemantics()) {
     return op->emitError("mixing tensor and buffer semantics is not allowed");
   }
-  bool hasTensorSemantics = packOrUnPack.hasPureTensorSemantics();
 
   // Verify inner_dims_pos and outer_dims_perm.
   ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 7891067323165..9a5c792aea852 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -63,14 +63,8 @@ getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                           OpTy packOrUnPackOp) {
   static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
                 "applies to only pack or unpack operations");
-  if constexpr (std::is_same_v<OpTy, linalg::PackOp>) {
-    if (!packOrUnPackOp.hasPureTensorSemantics())
-      return failure();
-  }
-  if constexpr (std::is_same_v<OpTy, linalg::UnPackOp>) {
-    if (!packOrUnPackOp.hasPureTensorSemantics())
-      return failure();
-  }
+  if (!packOrUnPackOp.hasPureTensorSemantics())
+    return failure();
 
   LLVM_DEBUG(
       { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });

>From 865d90c373fe914ec60bb4ad5990b3527db3aae3 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sun, 13 Apr 2025 12:32:44 +0900
Subject: [PATCH 29/33] fix upon review

---
 .../Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp  |   4 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  33 +-
 .../Transforms/DataLayoutPropagation.cpp      |   6 -
 .../Dialect/Linalg/Transforms/Transforms.cpp  |  23 +-
 .../Linalg/Transforms/Vectorization.cpp       |  12 +-
 mlir/test/Dialect/Linalg/canonicalize.mlir    | 166 +++++----
 mlir/test/Dialect/Linalg/invalid.mlir         | 340 +++++++++---------
 7 files changed, 301 insertions(+), 283 deletions(-)

diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index bb2b474814824..ad9621257f5df 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -77,10 +77,8 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
       if constexpr (llvm::is_one_of<T, omp::AtomicUpdateOp, omp::AtomicWriteOp,
                                     omp::FlushOp, omp::MapBoundsOp,
                                     omp::ThreadprivateOp>::value) {
-        if (isa<MemRefType>(originalOperand.getType())) {
-          // TODO: Support Memref PackOp. Temporarily return failure.
+        if (isa<MemRefType>(originalOperand.getType()))
           return rewriter.notifyMatchFailure(op, "memref is not supported yet");
-        }
       }
       convertedOperands.push_back(convertedOperand);
     }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ea7ea0694e7a3..f01e2f96e19d6 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -10,7 +10,6 @@
 //
 //===----------------------------------------------------------------------===//
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include <iostream>
 
 #include "mlir/AsmParser/AsmParser.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -4486,13 +4485,23 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   // Verify result shape is greater than the minimum expected
   // by the pack operation, and that the output shape
   // represents full tiles.
-  auto expectedPackedShape = PackOp::inferPackedShape(
+  SmallVector<int64_t> expectedPackedShape = PackOp::inferPackedShape(
       unpackedType.getShape(), packOrUnPack.getStaticTiles(),
       packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
+
   if (!areAllInBound(expectedPackedShape, packedType.getShape())) {
+    auto elementType = unpackedType.getElementType();
+    Type expectedType, actualType;
+    if (packOrUnPack.hasPureTensorSemantics()) {
+      expectedType = RankedTensorType::get(expectedPackedShape, elementType);
+      actualType = RankedTensorType::get(packedType.getShape(), elementType);
+    } else {
+      expectedType = MemRefType::get(expectedPackedShape, elementType);
+      actualType = MemRefType::get(packedType.getShape(), elementType);
+    }
     return op->emitError("the shape of output is not large enough to hold the "
                          "packed data. Expected at least ")
-           << expectedPackedShape << ", got " << packedType.getShape();
+           << expectedType << ", got " << actualType;
   }
   if (!llvm::all_of(
           llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
@@ -5033,15 +5042,24 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     // Insert a cast if needed
     if (needUpdateDestType) {
       rewriter.setInsertionPointAfter(packOp);
+      // if (hasTensorSemantics) {
+      //   auto castOp =
+      //       rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
+      //   rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+      // } else {
+      //   auto castOp =
+      //       rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
+      //   rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+      // }
+      Operation *castOp;
       if (hasTensorSemantics) {
-        auto castOp =
+        castOp =
             rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
-        rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
       } else {
-        auto castOp =
+        castOp =
             rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
-        rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
       }
+      rewriter.replaceAllUsesExcept(packOp, castOp->getResult(0), castOp);
     }
     return success();
   }
@@ -5423,6 +5441,7 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
     if (!tensor::hasFoldableTensorCastOperand(op))
       return failure();
 
+    // TODO: Support Memref PackOp. Temporarily return failure.
     if (!op.hasPureTensorSemantics())
       return failure();
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 9a5c792aea852..5f38f5a84ac64 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -63,9 +63,6 @@ getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                           OpTy packOrUnPackOp) {
   static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
                 "applies to only pack or unpack operations");
-  if (!packOrUnPackOp.hasPureTensorSemantics())
-    return failure();
-
   LLVM_DEBUG(
       { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
 
@@ -376,9 +373,6 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
 static FailureOr<GenericOp>
 bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
                                const ControlPropagationFn &controlFn) {
-  if (!packOp.hasPureTensorSemantics())
-    return failure();
-
   auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
   if (!genericOp)
     return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 49a2dbed14e75..c8a930aec60cd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -219,7 +219,6 @@ struct PackedOperandsDimList {
 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                              linalg::PackOp packOp,
                                              bool lowerPadLikeWithInsertSlice) {
-  // TODO: Support Memref PackOp. Temporarily return failure.
   if (!packOp.hasPureTensorSemantics())
     return failure();
 
@@ -359,19 +358,14 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
 FailureOr<LowerUnPackOpResult>
 linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
                     bool lowerUnpadLikeWithExtractSlice) {
-  // TODO: Support Memref PackOp. Temporarily return failure.
-  if (!unPackOp.hasPureTensorSemantics()) {
+  if (!unPackOp.hasPureTensorSemantics())
     return failure();
-  }
 
   Location loc = unPackOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
 
-  auto packedTensorType = dyn_cast<RankedTensorType>(unPackOp.getSourceType());
-  if (!packedTensorType)
-    return failure();
-
+  auto packedTensorType = cast<RankedTensorType>(unPackOp.getSourceType());
   int64_t packedRank = packedTensorType.getRank();
 
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -1038,14 +1032,14 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
 static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                            linalg::PackOp packOp) {
   Value input = packOp.getSource();
+  // TODO: Support Memref PackOp. Temporarily return just Op Source.
+  if (!packOp.hasPureTensorSemantics())
+    return input;
+
   if (!packOp.getPaddingValue()) {
     return input;
   }
 
-  // TODO: Support Memref PackOp. Temporarily return failure.
-  if (!packOp.hasPureTensorSemantics())
-    return packOp.getSource();
-
   assert(llvm::all_of(packOp.getAllOuterDims(),
                       [](int64_t val) { return val == 1; }) &&
          "some outer dims are != 1");
@@ -1158,7 +1152,6 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
 
 LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
     linalg::PackOp packOp, PatternRewriter &rewriter) const {
-  // TODO: Support Memref PackOp. Temporarily return failure.
   if (!packOp.hasPureTensorSemantics())
     return failure();
 
@@ -1263,10 +1256,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
 
 LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
     linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
-  // TODO: Support Memref PackOp. Temporarily return failure.
-  if (!unpackOp.hasPureTensorSemantics()) {
+  if (!unpackOp.hasPureTensorSemantics())
     return failure();
-  }
 
   int64_t srcRank = unpackOp.getSourceRank();
   int64_t destRank = unpackOp.getDestRank();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 89a61762e84f3..3a11921c50581 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1588,10 +1588,6 @@ static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
-  // TODO: Support Memref PackOp. Temporarily return failure.
-  if (!packOp.hasPureTensorSemantics())
-    return failure();
-
   // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(packOp);
@@ -1668,17 +1664,11 @@ static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
                           SmallVectorImpl<Value> &newResults) {
-  // TODO: Support Memref PackOp. Temporarily return failure.
-  if (!unpackOp.hasPureTensorSemantics())
-    return failure();
-
   // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);
 
-  auto unpackTensorType = dyn_cast<RankedTensorType>(unpackOp.getSourceType());
-  if (!unpackTensorType)
-    return failure();
+  auto unpackTensorType = cast<RankedTensorType>(unpackOp.getSourceType());
 
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 86cb8f58abe02..eafbb99caecaa 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1,30 +1,42 @@
-// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file |
+// FileCheck %s
 
 // CHECK-LABEL: func @memref_cast(
-func.func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c8 = arith.constant 8 : index
-  %c16 = arith.constant 16 : index
-  %1 = memref.alloc (%b) : memref<?xi8>
-  %2 = memref.view %1[%c0][] : memref<?xi8> to memref<16x16xf32>
-  %3 = memref.cast %2 : memref<16x16xf32> to memref<?x?xf32>
-
-  // CHECK:  linalg.matmul ins({{.*}}memref<16x16xf32>, memref<16x16xf32>) outs({{.*}}memref<16x16xf32>)
-  linalg.matmul ins(%3, %3: memref<?x?xf32>, memref<?x?xf32>)
-               outs(%3: memref<?x?xf32>)
-  return %3: memref<?x?xf32>
-}
+func.func @memref_cast(% a : index, % b : index)->memref < ? x ? xf32>{
+  % c0 = arith.constant 0 : index %
+         c1 = arith.constant 1 : index %
+              c8 = arith.constant 8 : index %
+                   c16 = arith.constant 16 : index %
+                         1 = memref.alloc(% b) : memref <
+                                 ? xi8 > % 2 = memref.view % 1 [% c0][]
+                             : memref < ? xi8 > to memref<16x16xf32> %
+                                                    3 = memref.cast % 2
+                                        : memref<16x16xf32> to memref <
+                                 ? x
+                                 ? xf32 >
+
+                                       // CHECK:  linalg.matmul
+                                       // ins({{.*}}memref<16x16xf32>,
+                                       // memref<16x16xf32>)
+                                       // outs({{.*}}memref<16x16xf32>)
+                                       linalg.matmul ins(
+                                           % 3, % 3 : memref < ? x ? xf32 >,
+                                           memref < ? x
+                                           ? xf32 >)
+                                                 outs(% 3 : memref <
+                                                      ? x ? xf32 >) return % 3
+                                                          : memref <
+                                                      ? x
+                                                      ? xf32 > }
 
 // -----
 
 #accesses = [
-  affine_map<(i) -> (i)>
-]
+                                                        affine_map<(i)->(i)>]
 
 #trait = {
-  indexing_maps = #accesses,
-  iterator_types = ["parallel"]
+                                                      indexing_maps = #accesses,
+                                           iterator_types = ["parallel"]
 }
 
 func.func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
@@ -117,7 +129,7 @@ func.func @linalg_effects(
 
 // -----
 
-#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map = affine_map < (d0, d1, d2)->(d0, d1, d2)>
 func.func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
   -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
   %c0 = arith.constant 0 : index
@@ -144,7 +156,7 @@ func.func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
 
 // -----
 
-#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map = affine_map < (d0, d1, d2)->(d0, d1, d2)>
 func.func @remove_no_op_mismatched_types(%arg0 : tensor<?x?x?xf32>)
   -> tensor<1x2x3xf32> {
   %out = tensor.empty() : tensor<1x2x3xf32>
@@ -160,12 +172,12 @@ func.func @remove_no_op_mismatched_types(%arg0 : tensor<?x?x?xf32>)
 }
 // CHECK-LABEL: func @remove_no_op_mismatched_types
 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-//       CHECK:     %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<1x2x3xf32>
-//       CHECK:     return %[[CAST]]
+//       CHECK:     %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<1x2x3xf32>
+//       CHECK:     return %[[CAST]]
 
 // -----
 
-#map = affine_map<() -> ()>
+#map = affine_map < ()->()>
 func.func @cant_fold_to_tensor_cast(%arg0 : f32) -> tensor<f32> {
   %out = tensor.empty() : tensor<f32>
   %g = linalg.generic {
@@ -183,7 +195,7 @@ func.func @cant_fold_to_tensor_cast(%arg0 : f32) -> tensor<f32> {
 
 // -----
 
-#map = affine_map<(d0, d1) -> (d0, d1)>
+#map = affine_map < (d0, d1)->(d0, d1)>
 func.func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -208,7 +220,7 @@ func.func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
 
 // -----
 
-#map = affine_map<(d0, d1) -> (d0, d1)>
+#map = affine_map < (d0, d1)->(d0, d1)>
 func.func @keep_not_noop(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>)
   -> (tensor<?x?xf32>, tensor<?x?xf32>) {
   %c0 = arith.constant 0 : index
@@ -386,7 +398,7 @@ func.func @fill_pack_general() -> tensor<1x1x8x4x4x8xi32>{
 
 // -----
 
-#map = affine_map<()[s0] -> (s0 ceildiv 16)>
+#map = affine_map < ()[s0]->(s0 ceildiv 16)>
 func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
@@ -495,11 +507,15 @@ func.func @no_fold_pad_fill_value_mismatch() -> tensor<412x276xf32> {
 
 // -----
 
-// Tests below verify whether static information is propagated through all the operands of generic op.
-// 1. If one of the inputs of generic op has static info and it has no cast source.
-// 2. If one of the inputs of generic op has static info and it is coming from tensr.cast operation.
-// 3. If one of the outputs of generic op has static info and it is coming from tenso.cast operation.
-#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// Tests below verify whether static information is propagated through all the
+// operands of a generic op.
+// 1. If one of the inputs of the generic op has static info and it has no cast
+// source.
+// 2. If one of the inputs of the generic op has static info and it is coming
+// from a tensor.cast operation.
+// 3. If one of the outputs of the generic op has static info and it is coming
+// from a tensor.cast operation.
+#map = affine_map < (d0, d1, d2)->(d0, d1, d2)>
 // CHECK-LABEL: func @static_input_without_cast
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
 func.func @static_input_without_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
@@ -529,7 +545,7 @@ func.func @static_input_without_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x
 
 // -----
 
-#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map = affine_map < (d0, d1, d2)->(d0, d1, d2)>
 // CHECK-LABEL: func @static_input_with_cast
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
 func.func @static_input_with_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
@@ -560,7 +576,7 @@ func.func @static_input_with_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?
 
 // -----
 
-#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map = affine_map < (d0, d1, d2)->(d0, d1, d2)>
 // CHECK-LABEL: func @static_output_with_cast
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>, %[[ARG2:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
 func.func @static_output_with_cast(%arg0 : tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
@@ -592,9 +608,9 @@ func.func @static_output_with_cast(%arg0 : tensor<?x?x?xf32>, %arg1: tensor<?x?x
 
 // -----
 
-// This test checks the folding of tensor.cast operation when the source value of cast
-// has more static information than the destination value.
-#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// This test checks the folding of tensor.cast operation when the source value
+// of cast has more static information than the destination value.
+#map = affine_map < (d0, d1, d2)->(d0, d1, d2)>
 // CHECK-LABEL: func @cast_source
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
 func.func @cast_source(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
@@ -625,7 +641,7 @@ func.func @cast_source(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> t
 
 // -----
 
-#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map = affine_map < (d0, d1, d2)->(d0, d1, d2)>
 // CHECK-LABEL: func @cast_dest
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<1x?x?xf32>,
 func.func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2: index, %arg3: index, %arg4: index) -> tensor<?x?x?xf32> {
@@ -649,8 +665,9 @@ func.func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2:
 
 // -----
 
-#map = affine_map<(d0, d1) -> (d0, d1)>
-#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+#map = affine_map < (d0, d1)->(d0, d1)>
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+
 // CHECK-DAG:   #[[$SPARSE:.+]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
 // CHECK-LABEL: func @static_shape_inference_with_encoding(
 // CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
@@ -839,23 +856,25 @@ func.func @linalgop_with_cond_cast_consumer(%arg0 : tensor<?x?xf32>, %arg1 : ten
   %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
   scf.if %arg3 {
-    %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
-    func.call @some_use(%1) : (tensor<4x8xf32>) -> ()
+    %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
+    func.call @some_use(%1)
+        : (tensor<4x8xf32>) -> ()
   }
-  return %0 : tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
 }
 
 // Check conditionally reachable cast is not folded into producer.
 // CHECK-LABEL: func @linalgop_with_cond_cast_consumer
-//  CHECK-SAME:     (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>, %[[ARG2:.*]]: tensor<?x?xf32>, %[[ARG3:.*]]: i1)
-//       CHECK: %[[RES:.*]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
+//  CHECK-SAME:     (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>, %[[ARG2:.*]]: tensor<?x?xf32>, %[[ARG3:.*]]: i1)
+//       CHECK: %[[RES:.*]] = linalg.matmul
+//  CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] :
+//  CHECK-SAME:     tensor<?x?xf32>, tensor<?x?xf32>)
 //  CHECK-SAME:      outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 //       CHECK: scf.if %[[ARG3]] {
-//       CHECK:   %[[CAST:.*]] = tensor.cast %[[RES]] : tensor<?x?xf32> to tensor<4x8xf32>
-//       CHECK:   func.call @some_use(%[[CAST]]) : (tensor<4x8xf32>) -> ()
-//       CHECK: }
-//       CHECK: return %[[RES]] : tensor<?x?xf32>
-
+//       CHECK:   %[[CAST:.*]] = tensor.cast %[[RES]] : tensor<?x?xf32> to tensor<4x8xf32>
+//       CHECK:   func.call @some_use(%[[CAST]]) : (tensor<4x8xf32>) -> ()
+//       CHECK: }
+//       CHECK: return %[[RES]] : tensor<?x?xf32>
 
 // -----
 
@@ -904,17 +923,19 @@ func.func @fold_multi_use_generic_op_with_consumer(%arg0 : tensor<?x?x?xf32>) ->
 //       CHECK: func @fold_multi_use_generic_op_with_consumer
 //  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
 //   CHECK-DAG:   %[[INIT1:.+]] = tensor.empty() : tensor<2x3x4xf32>
-//   CHECK-DAG:   %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<4x3x2xf32>
-//   CHECK-DAG:   %[[INIT2:.+]] = tensor.empty() : tensor<3x2x4xf32>
+//   CHECK-DAG:   %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<4x3x2xf32>
+//   CHECK-DAG:   %[[INIT2:.+]] = tensor.empty() : tensor<3x2x4xf32>
+//
 //       CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
 //  CHECK-SAME:       ins(%[[CAST]] :
 //  CHECK-SAME:       outs(%[[INIT2]], %[[INIT1]] :
-//       CHECK:   %[[RETURN_CAST:.+]] = tensor.cast %[[GENERIC]]#0 : tensor<3x2x4xf32> to tensor<?x?x?xf32>
-//       CHECK:   return %[[RETURN_CAST]], %[[GENERIC]]#1
+//       CHECK:   %[[RETURN_CAST:.+]] = tensor.cast %[[GENERIC]]#0 :
+//  CHECK-SAME:     tensor<3x2x4xf32> to tensor<?x?x?xf32>
+//       CHECK:   return %[[RETURN_CAST]], %[[GENERIC]]#1
 
 // -----
 
-#map = affine_map<(d0) -> (d0)>
+#map = affine_map < (d0)->(d0)>
 func.func @identity_buffer(%arg0 : memref<?xf32>, %arg1: memref<?xf32>) {
   linalg.generic {
     indexing_maps = [#map, #map],
@@ -938,7 +959,7 @@ func.func @identity_buffer(%arg0 : memref<?xf32>, %arg1: memref<?xf32>) {
 
 // -----
 
-#map = affine_map<(d0, d1) -> (d1, d0)>
+#map = affine_map < (d0, d1)->(d1, d0)>
 func.func @erase_non_identity_noop(%arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.generic {
     indexing_maps = [#map, #map],
@@ -1722,6 +1743,21 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
 
 // -----
 
+func.func @infer_and_fold_pack_unpack_same_tiles_memref(%t: memref<10x20x4x4xf32>) -> memref<10x20x4x4xf32> {
+  %c40 = arith.constant 40 : index
+  %c80 = arith.constant 80 : index
+  %buf_unpacked = memref.alloc() : memref<40x80xf32>
+  %unpacked = linalg.unpack %t inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %buf_unpacked : memref<10x20x4x4xf32> -> memref<40x80xf32>
+  %buf_packed = memref.alloc() : memref<10x20x4x4xf32>
+  %packed = linalg.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %buf_packed : memref<40x80xf32> -> memref<10x20x4x4xf32>
+  return %packed : memref<10x20x4x4xf32>
+}
+// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles_memref
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK:         return %[[SRC]]
+
+// -----
+
 // CHECK-LABEL:   func.func @pack_dont_drop_attributes(
 // CHECK: linalg.pack {{.*}}  {test_attr}
 func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
@@ -1830,14 +1866,16 @@ func.func @no_fold_extract_slice_into_unpack_rank_reducing(
 func.func @no_fold_extract_slice_into_unpack_non_zero_offset(
     %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
 ) -> tensor<28x28xf32> {
-  %unpack = linalg.unpack %src
-      outer_dims_perm = [0, 1]
-      inner_dims_pos = [1]
-      inner_tiles = [16]
-      into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
-  %extracted_slice = tensor.extract_slice %unpack
-      [0, 1] [28, 28] [1, 1] : tensor<28x32xf32> to tensor<28x28xf32>
-  return %extracted_slice : tensor<28x28xf32>
+  %unpack = linalg.unpack %src
+      outer_dims_perm = [0, 1]
+      inner_dims_pos = [1]
+      inner_tiles = [16]
+      into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 1] [28, 28] [1, 1]
+      : tensor<28x32xf32> to tensor<28x28xf32>
+  return %extracted_slice
+      : tensor<28x28xf32>
 }
 
 // CHECK-LABEL: func @no_fold_extract_slice_into_unpack_non_zero_offset
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index aa12778ffbf7f..8177f1ee98584 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -32,90 +32,74 @@ func.func @index_parent() {
 // -----
 
 func.func @index_dim_lower_than_number_of_loops(%arg0: memref<f32>) {
-  // expected-error @+6 {{op expected dim (2) to be lower than the number of loops (0) of the enclosing LinalgOp}}
-  linalg.generic {
-      indexing_maps =  [ affine_map<() -> ()> ],
-      iterator_types = []}
-      outs(%arg0 : memref<f32>) {
-    ^bb(%0: f32):
-      linalg.index 2 : index
-      linalg.yield %0 : f32
+  // expected-error @+3 {{op expected dim (2) to be lower than the number of loops (0) of the enclosing LinalgOp}}
+  linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []}
+      outs(%arg0 : memref<f32>) {
+    ^bb(%0: f32): linalg.index 2 : index
+      linalg.yield %0 : f32
   }
 }
 
 // -----
 
-func.func @index_dim_negative(%arg0: memref<f32>) {
-  // expected-error @+6 {{op attribute 'dim' failed to satisfy constraint: 64-bit signless integer attribute whose minimum value is 0}}
-  linalg.generic {
-      indexing_maps =  [ affine_map<() -> ()> ],
-      iterator_types = []}
-      outs(%arg0 : memref<f32>) {
-    ^bb(%0: f32):
-      linalg.index -1 : index
-      linalg.yield %0 : f32
+func.func @index_dim_negative(%arg0: memref<f32>) {
+  // expected-error @+3 {{op attribute 'dim' failed to satisfy constraint: 64-bit signless integer attribute whose minimum value is 0}}
+  linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []}
+      outs(%arg0 : memref<f32>) {
+    ^bb(%0: f32): linalg.index -1 : index
+      linalg.yield %0 : f32
   }
 }
 
 // -----
 
-func.func @generic_no_region(%arg0: memref<f32>) {
-  // expected-error @+4 {{expected '{' to begin a region}}
-  linalg.generic {
-    indexing_maps =  [ affine_map<() -> (0)> ],
-    iterator_types = []
-  } ins(%arg0 : memref<f32>)
-}
+func.func @generic_no_region(%arg0: memref<f32>) {
+  // expected-error @+1 {{expected '{' to begin a region}}
+  linalg.generic {indexing_maps = [affine_map<() -> (0)>], iterator_types = []} ins(%arg0 : memref<f32>)
+}
 
 // -----
 
-func.func @generic_mismatched_num_returns(%arg0: memref<f32>) {
-  // expected-error @+6 {{op expected number of yield values (0) to match the number of inits / outs operands of the enclosing LinalgOp (1)}}
-  linalg.generic {
-      indexing_maps =  [ affine_map<() -> ()> ],
-      iterator_types = []}
-      outs(%arg0 : memref<f32>) {
-    ^bb(%0: f32):
-      linalg.yield
+func.func @generic_mismatched_num_returns(%arg0: memref<f32>) {
+  // expected-error @+4 {{op expected number of yield values (0) to match the number of inits / outs operands of the enclosing LinalgOp (1)}}
+  linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []}
+      outs(%arg0 : memref<f32>) {
+    ^bb(%0: f32):
+      linalg.yield
   }
 }
 
 // -----
 
-func.func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
-  // expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
-  linalg.generic {
-    indexing_maps =  [ affine_map<() -> (0)> ],
-    iterator_types = ["parallel"]}
-      outs(%arg0 : memref<1xi32>) {
-    ^bb(%i : i32):
-    linalg.yield %i : i32
+func.func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
+  // expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
+  linalg.generic {indexing_maps = [affine_map<() -> (0)>], iterator_types = ["parallel"]}
+      outs(%arg0 : memref<1xi32>) {
+    ^bb(%i : i32):
+    linalg.yield %i : i32
   }
 }
 
 // -----
 
-func.func @generic_wrong_iterator(%arg0: memref<1xi32>) {
+func.func @generic_wrong_iterator(% arg0 : memref<1xi32>) {
   // expected-error @+4 {{unexpected iterator_type (random)}}
-  linalg.generic {
-    indexing_maps =  [ affine_map<(i) -> (i)> ],
-    iterator_types = ["random"]}
-      outs(%arg0 : memref<1xi32>) {
-    ^bb(%i : i32):
-    linalg.yield %i : i32
+  linalg.generic{indexing_maps = [affine_map<(i)->(i)>],
+                 iterator_types = ["random"]} outs(% arg0 : memref<1xi32>) {
+    ^bb(% i : i32) : linalg.yield % i : i32
   }
 }
 
 // -----
 
 func.func @generic_one_d_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-  // expected-error @+1 {{expected operand rank (1) to match the result rank of indexing_map #0 (2)}}
+  // The indexing map below yields two results for a rank-1 operand.
+  // expected-error @+1 {{expected operand rank (1) to match the result rank of indexing_map #0 (2)}}
   linalg.generic {
     indexing_maps =  [ affine_map<() -> (0, 0)> ],
     iterator_types = []}
       outs(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-    ^bb(%f : f32):
-      linalg.yield %f: f32
+    ^bb(%f : f32): linalg.yield %f : f32
   }
 }
 
@@ -129,22 +113,20 @@ func.func @generic_scalar_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off +
     iterator_types = []}
       ins(%cst : f32)
       outs(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-    ^bb(%0 : f32, %1 : f32):
-      linalg.yield %0: f32
+    ^bb(%0 : f32, %1 : f32): linalg.yield %0 : f32
   }
 }
 
 // -----
 
 func.func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-  // expected-error @+7 {{'linalg.yield' op type of yield operand 1 ('i4') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
+  // The region yields an i4 value although the output element type is f32.
+  // expected-error @+5 {{'linalg.yield' op type of yield operand 1 ('i4') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
   linalg.generic {
     indexing_maps =  [ affine_map<(i) -> (i)> ],
     iterator_types = ["parallel"]}
       outs(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-    ^bb(%0: f32):
-      %1 = arith.constant 1: i4
-      linalg.yield %1: i4
+    ^bb(%0: f32): %1 = arith.constant 1 : i4  linalg.yield %1 : i4
   }
 }
 
@@ -160,8 +142,7 @@ func.func @generic_singular_maps(%arg0: memref<?xf32, affine_map<(i)[off]->(off
     iterator_types = ["parallel","parallel"]}
     ins(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>)
    outs(%arg1 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-  ^bb(%0: f32, %1: f32):
-      linalg.yield %1: f32
+    ^bb(%0 : f32, %1 : f32): linalg.yield %1 : f32
   }
 }
 
@@ -171,57 +152,53 @@ func.func @generic_singular_maps(%arg0: memref<?xf32, affine_map<(i)[off]->(off
 
 // -----
 
-func.func @generic_empty_region(%arg0: memref<f32>) {
-  %f0 = arith.constant 0.0: f32
-  // expected-error @+1 {{op expects region #0 to have 0 or 1 blocks}}
-  linalg.generic {
-    indexing_maps =  [ affine_map<() -> ()>, affine_map<() -> ()> ],
-    iterator_types = []}
-      ins(%arg0 : memref<f32>)
-     outs(%arg0 : memref<f32>) {
-    ^bb1:
-      linalg.yield %f0: f32
-    ^bb2:
-      linalg.yield %f0: f32
+func.func @generic_empty_region(%arg0: memref<f32>) {
+  %f0 = arith.constant 0.0: f32
+  // expected-error @+1 {{op expects region #0 to have 0 or 1 blocks}}
+  linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+                  iterator_types = []}
+      ins(%arg0 : memref<f32>) outs(%arg0 : memref<f32>) {
+    ^bb1:
+      linalg.yield %f0: f32
+    ^bb2:
+      linalg.yield %f0: f32
   }
 }
 
 // -----
 
-func.func @generic_empty_region(%arg0: memref<f32>) {
-  %f0 = arith.constant 0.0: f32
-  // expected-error @+1 {{op expects to have 1 region with 1 block}}
-  linalg.generic {
-    indexing_maps =  [ affine_map<() -> ()> , affine_map<() -> ()> ],
-    iterator_types = []}
-    ins(%arg0 : memref<f32>)
-   outs(%arg0 : memref<f32>) {
-  }
+func.func @generic_empty_region(%arg0: memref<f32>) {
+  %f0 = arith.constant 0.0: f32
+  // expected-error @+1 {{op expects to have 1 region with 1 block}}
+  linalg.generic {
+    indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+    iterator_types = []}
+    ins(%arg0 : memref<f32>)
+   outs(%arg0 : memref<f32>) {
+  }
 }
 
 // -----
 
-func.func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
-  // expected-error @+6 {{'linalg.yield' op expected number of yield values (1) to match the number of inits / outs operands of the enclosing LinalgOp (2)}}
-  linalg.generic {
-      indexing_maps =  [ affine_map<() -> ()>, affine_map<() -> ()> ],
-      iterator_types = []}
-      outs(%arg0, %arg0 : memref<f32>, memref<f32>) {
-    ^bb(%f: f32):
-      linalg.yield %f: f32
+func.func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
+  // expected-error @+6 {{'linalg.yield' op expected number of yield values (1) to match the number of inits / outs operands of the enclosing LinalgOp (2)}}
+  linalg.generic {
+      indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+      iterator_types = []}
+      outs(%arg0, %arg0 : memref<f32>, memref<f32>) {
+    ^bb(%f: f32):
+      linalg.yield %f: f32
   }
 }
 
 // -----
 
-func.func @generic_shaped_operand_block_arg_type(%arg0: memref<f32>) {
-  // expected-error @+6 {{'linalg.yield' op type of yield operand 1 ('i1') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
-  linalg.generic {
-    indexing_maps =  [ affine_map<() -> ()> ],
-    iterator_types = []}
-      outs(%arg0 : memref<f32>) {
-    ^bb(%i: i1):
-    linalg.yield %i : i1
+func.func @generic_shaped_operand_block_arg_type(%arg0: memref<f32>) {
+  // expected-error @+4 {{'linalg.yield' op type of yield operand 1 ('i1') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
+  linalg.generic {indexing_maps = [affine_map<() -> ()>], iterator_types = []}
+      outs(%arg0 : memref<f32>) {
+    ^bb(%i: i1):
+    linalg.yield %i : i1
   }
 }
 
@@ -241,14 +218,13 @@ func.func @generic_scalar_operand_block_arg_type(%arg0: tensor<f32>) {
 // -----
 
 func.func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-  // expected-error @+7 {{type of yield operand 1 ('i1') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
+  // The region yields an i1 value although the output element type is f32.
+  // expected-error @+5 {{type of yield operand 1 ('i1') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
   linalg.generic {
     indexing_maps = [ affine_map<(i) -> (i)> ],
     iterator_types = ["parallel"]}
       outs(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-    ^bb(%i: f32):
-      %0 = arith.constant 0: i1
-      linalg.yield %0: i1
+    ^bb(%i: f32): %0 = arith.constant 0 : i1  linalg.yield %0 : i1
   }
 }
 
@@ -664,83 +640,96 @@ func.func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<
         iterator_types = ["parallel"]
 }
 
-func.func @invalid_reverse(%A: memref<5xf32>, %B: memref<5xf32>) {
-  // expected-error @+1 {{unexpected result less than 0 at expression #0 in}}
-  linalg.generic #attrs ins(%A: memref<5xf32>) outs(%B: memref<5xf32>) {
-                ^bb0(%a: f32, %b: f32):
-                linalg.yield %a : f32
-        }
-        return
-}
-
-// -----
-
-func.func @map_binary_wrong_yield_operands(
-    %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
-    -> tensor<64xf32> {
-   %add = linalg.map
-          ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
-          outs(%init:tensor<64xf32>)
-          (%lhs_elem: f32, %rhs_elem: f32) {
-            %0 = arith.addf %lhs_elem, %rhs_elem: f32
-            // expected-error @+1{{'linalg.yield' op expected number of yield values (2) to match the number of inits / outs operands of the enclosing LinalgOp (1)}}
-            linalg.yield %0, %0: f32, f32
-          }
-  func.return %add : tensor<64xf32>
-}
-
-// -----
-
-func.func @map_input_mapper_arity_mismatch(
-    %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
-    -> tensor<64xf32> {
-  // expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 2 and 3}}
-  %add = linalg.map
-      ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
-      outs(%init:tensor<64xf32>)
-      (%lhs_elem: f32, %rhs_elem: f32, %extra_elem: f32) {
-        %0 = arith.addf %lhs_elem, %rhs_elem: f32
-        linalg.yield %0: f32
-      }
-  func.return %add : tensor<64xf32>
-}
-
-// -----
-
-func.func @map_input_mapper_type_mismatch(
-    %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
-    -> tensor<64xf32> {
-    // expected-error@+1{{'linalg.map' op expected element type of input 'f32' to match bbArg type 'f64'}}
-  %add = linalg.map
-      ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
-      outs(%init:tensor<64xf32>)
-      (%lhs_elem: f64, %rhs_elem: f64) {
-        %0 = arith.addf %lhs_elem, %rhs_elem: f64
-        linalg.yield %0: f64
-      }
-  func.return %add : tensor<64xf32>
-}
-
-// -----
-
-func.func @map_input_output_shape_mismatch(
-    %lhs: tensor<64x64xf32>, %rhs: tensor<64x64xf32>, %init: tensor<32xf32>)
-    -> tensor<32xf32> {
-    // expected-error@+1{{'linalg.map' op expected shape of input (64, 64) to match shape of output (32)}}
-  %add = linalg.map
-      ins(%lhs, %rhs : tensor<64x64xf32>, tensor<64x64xf32>)
-      outs(%init:tensor<32xf32>)
-      (%lhs_elem: f32, %rhs_elem: f32) {
-        %0 = arith.addf %lhs_elem, %rhs_elem: f32
-        linalg.yield %0: f32
-      }
-  func.return %add : tensor<32xf32>
-}
+func.func @invalid_reverse(%A: memref<5xf32>, %B: memref<5xf32>) {
+  // expected-error @+1 {{unexpected result less than 0 at expression #0 in}}
+  linalg.generic #attrs ins(%A: memref<5xf32>) outs(%B: memref<5xf32>) {
+    ^bb0(%a: f32, %b: f32): linalg.yield %a : f32
+  } return }
+
+// -----
+
+func.func @map_binary_wrong_yield_operands(% lhs : tensor<64xf32>,
+                                           % rhs : tensor<64xf32>,
+                                           % init : tensor<64xf32>)
+    ->tensor<64xf32>{
+      % add =
+          linalg
+              .map ins(% lhs, % rhs : tensor<64xf32>, tensor<64xf32>) outs(
+                  % init : tensor<64xf32>)(% lhs_elem : f32, % rhs_elem : f32){
+                % 0 = arith.addf % lhs_elem,
+                %
+                    rhs_elem : f32
+                        // expected-error @+1{{'linalg.yield' op expected number
+                        // of yield values (2) to match the number of inits /
+                        // outs operands of the enclosing LinalgOp (1)}}
+                            linalg.yield %
+                    0,
+                % 0 : f32,
+                f32
+              } func.return %
+          add : tensor<64xf32>
+    }
+
+// -----
+
+func.func @map_input_mapper_arity_mismatch(%lhs: tensor<64xf32>,
+                                           %rhs: tensor<64xf32>,
+                                           %init: tensor<64xf32>)
+    -> tensor<64xf32> {
+  // expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 2 and 3}}
+  %add = linalg.map
+      ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+      outs(%init : tensor<64xf32>)
+      (%lhs_elem: f32,
+       %rhs_elem: f32,
+       %extra_elem: f32) {
+        %0 = arith.addf %lhs_elem, %rhs_elem : f32
+        linalg.yield %0 : f32
+      }
+  func.return %add : tensor<64xf32>
+}
+
+// -----
+
+func.func @map_input_mapper_type_mismatch(%lhs: tensor<64xf32>,
+                                          %rhs: tensor<64xf32>,
+                                          %init: tensor<64xf32>)
+    -> tensor<64xf32> {
+  // expected-error@+1{{'linalg.map' op expected element type of input 'f32' to match bbArg type 'f64'}}
+  %add = linalg.map
+      ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+      outs(%init : tensor<64xf32>)
+      (%lhs_elem: f64,
+       %rhs_elem: f64) {
+        %0 = arith.addf %lhs_elem, %rhs_elem : f64
+        linalg.yield %0 : f64
+      }
+  func.return %add : tensor<64xf32>
+}
+
+// -----
+
+func.func @map_input_output_shape_mismatch(%lhs: tensor<64x64xf32>,
+                                           %rhs: tensor<64x64xf32>,
+                                           %init: tensor<32xf32>)
+    -> tensor<32xf32> {
+  // expected-error@+1{{'linalg.map' op expected shape of input (64, 64) to match shape of output (32)}}
+  %add = linalg.map
+      ins(%lhs, %rhs : tensor<64x64xf32>, tensor<64x64xf32>)
+      outs(%init : tensor<32xf32>)
+      (%lhs_elem: f32,
+       %rhs_elem: f32) {
+        %0 = arith.addf %lhs_elem, %rhs_elem : f32
+        linalg.yield %0 : f32
+      }
+  func.return %add : tensor<32xf32>
+}
 
 // -----
 
 func.func @map_no_operands1() {
-  // expected-error @+1 {{'linalg.map' op expected 1 or more operands, but found 0}}
+  // linalg.map below is given no operands at all.
+  // expected-error @+1 {{'linalg.map' op expected 1 or more operands, but found 0}}
   linalg.map { arith.addf }
 }
 
@@ -1676,7 +1665,6 @@ func.func @pack_source_dest_type_mismatch_1(%source: tensor<128x256xf32>, %dest:
   return
 }
 
-
 // -----
 
 func.func @pack_source_dest_type_mismatch_2(%source: memref<128x256xf32>, %dest: tensor<8x16x8x32xf32>) {
@@ -1699,7 +1687,7 @@ func.func @unpack_source_dest_type_mismatch_1(%source: tensor<16x8x8x32xf32>, %d
 
 func.func @unpack_source_dest_type_mismatch_1(%source: memref<16x8x8x32xf32>, %dest: tensor<128x256xf32>) {
   // expected-error@+1 {{mixing tensor and buffer semantics is not allowed}}
-  %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
-      into %dest : memref<16x8x8x32xf32> -> tensor<128x256xf32>
-  return
+  %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+      into %dest : memref<16x8x8x32xf32> -> tensor<128x256xf32>
+  return
 }

>From 021f88098a8649b2be30c068b8889a998a26362f Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sun, 13 Apr 2025 12:48:01 +0900
Subject: [PATCH 30/33] nit

---
 mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 2 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp          | 9 ---------
 mlir/test/Dialect/Linalg/invalid.mlir             | 2 +-
 3 files changed, 2 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 99c80a2196567..3af89a6ab3799 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -451,7 +451,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
 ///    %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
 ///          tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
 ///
-///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
+///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] : 
 ///          tensor<1x1x1x10xf32> into tensor<1x10xf32>
 ///    %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
 ///          tensor<1x10xf32> into tensor<10x10xf32>
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f01e2f96e19d6..2aff7b67ce6dd 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5042,15 +5042,6 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     // Insert a cast if needed
     if (needUpdateDestType) {
       rewriter.setInsertionPointAfter(packOp);
-      // if (hasTensorSemantics) {
-      //   auto castOp =
-      //       rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
-      //   rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
-      // } else {
-      //   auto castOp =
-      //       rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
-      //   rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
-      // }
       Operation *castOp;
       if (hasTensorSemantics) {
         castOp =
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 8177f1ee98584..852180aa28055 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -662,7 +662,7 @@ func.func @map_binary_wrong_yield_operands(% lhs : tensor<64xf32>,
                         // expected-error @+1{{'linalg.yield' op expected number
                         // of yield values (2) to match the number of inits /
                         // outs operands of the enclosing LinalgOp (1)}}
-                            linalg.yield %
+                        linalg.yield %
                     0,
                 % 0 : f32,
                 f32

>From 3115f4c096fae9e88f9ffcb71e6b26615b46c430 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sun, 13 Apr 2025 13:41:10 +0900
Subject: [PATCH 31/33] revert

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp |   9 +-
 mlir/test/Dialect/Linalg/invalid.mlir    | 342 ++++++++++++-----------
 2 files changed, 182 insertions(+), 169 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2aff7b67ce6dd..82eb513ff940c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5013,10 +5013,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     return success();
   }
 
-  // Insert tensor.cast if static shape inference is available..
-  bool hasTensorSemantics = packOp.hasPureTensorSemantics();
-
-  // TODO: support memref.cast if static shape inference is available.
+  // Insert tensor.cast ops if static shape inference is available.
   SmallVector<int64_t> srcShape, destShape;
   if (inferStaticShape(packOp, srcShape, destShape)) {
     Location loc = packOp.getLoc();
@@ -5043,6 +5040,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     if (needUpdateDestType) {
       rewriter.setInsertionPointAfter(packOp);
       Operation *castOp;
+      bool hasTensorSemantics = packOp.hasPureTensorSemantics();
       if (hasTensorSemantics) {
         castOp =
             rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
@@ -5051,6 +5049,9 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
             rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
       }
       rewriter.replaceAllUsesExcept(packOp, castOp->getResult(0), castOp);
+    } else {
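+      // FIXME(review): reached whenever the dest shape is unchanged (tensor or memref); memref.cast is already handled above, so returning failure() here looks wrong - confirm.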
+      return failure();
     }
     return success();
   }
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 852180aa28055..b25d81c71ae1f 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -32,74 +32,90 @@ func.func @index_parent() {
 // -----
 
 func.func @index_dim_lower_than_number_of_loops(%arg0: memref<f32>) {
-  // expected-error @+6 {{op expected dim (2) to be lower than the number of
-  // loops (0) of the enclosing LinalgOp}}
-  linalg.generic{indexing_maps = [affine_map<()->()>],
-                 iterator_types = []} outs(% arg0 : memref<f32>) {
-    ^bb(% 0 : f32) : linalg.index 2 : index linalg.yield % 0 : f32
+  // expected-error @+6 {{op expected dim (2) to be lower than the number of loops (0) of the enclosing LinalgOp}}
+  linalg.generic {
+      indexing_maps =  [ affine_map<() -> ()> ],
+      iterator_types = []}
+      outs(%arg0 : memref<f32>) {
+    ^bb(%0: f32):
+      linalg.index 2 : index
+      linalg.yield %0 : f32
   }
 }
 
 // -----
 
-func.func @index_dim_negative(% arg0 : memref<f32>) {
-  // expected-error @+6 {{op attribute 'dim' failed to satisfy constraint:
-  // 64-bit signless integer attribute whose minimum value is 0}}
-  linalg.generic{indexing_maps = [affine_map<()->()>],
-                 iterator_types = []} outs(% arg0 : memref<f32>) {
-    ^bb(% 0 : f32) : linalg.index - 1 : index linalg.yield % 0 : f32
+func.func @index_dim_negative(%arg0: memref<f32>) {
+  // expected-error @+6 {{op attribute 'dim' failed to satisfy constraint: 64-bit signless integer attribute whose minimum value is 0}}
+  linalg.generic {
+      indexing_maps =  [ affine_map<() -> ()> ],
+      iterator_types = []}
+      outs(%arg0 : memref<f32>) {
+    ^bb(%0: f32):
+      linalg.index -1 : index
+      linalg.yield %0 : f32
   }
 }
 
 // -----
 
-func.func @generic_no_region(% arg0 : memref<f32>){
-    // expected-error @+4 {{expected '{' to begin a region}}
-    linalg.generic{indexing_maps = [affine_map<()->(0)>],
-                   iterator_types = []} ins(% arg0 : memref<f32>)}
+func.func @generic_no_region(%arg0: memref<f32>) {
+  // expected-error @+4 {{expected '{' to begin a region}}
+  linalg.generic {
+    indexing_maps =  [ affine_map<() -> (0)> ],
+    iterator_types = []
+  } ins(%arg0 : memref<f32>)
+}
 
 // -----
 
-func.func @generic_mismatched_num_returns(% arg0 : memref<f32>) {
-  // expected-error @+6 {{op expected number of yield values (0) to match the
-  // number of inits / outs operands of the enclosing LinalgOp (1)}}
-  linalg.generic{indexing_maps = [affine_map<()->()>],
-                 iterator_types = []} outs(% arg0 : memref<f32>) {
-    ^bb(% 0 : f32) : linalg.yield
+func.func @generic_mismatched_num_returns(%arg0: memref<f32>) {
+  // expected-error @+6 {{op expected number of yield values (0) to match the number of inits / outs operands of the enclosing LinalgOp (1)}}
+  linalg.generic {
+      indexing_maps =  [ affine_map<() -> ()> ],
+      iterator_types = []}
+      outs(%arg0 : memref<f32>) {
+    ^bb(%0: f32):
+      linalg.yield
   }
 }
 
 // -----
 
-func.func @generic_wrong_dim_in_map(% arg0 : memref<1xi32>) {
-  // expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match
-  // the number of loops}}
-  linalg.generic{indexing_maps = [affine_map<()->(0)>],
-                 iterator_types = ["parallel"]} outs(% arg0 : memref<1xi32>) {
-    ^bb(% i : i32) : linalg.yield % i : i32
+func.func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
+  // expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
+  linalg.generic {
+    indexing_maps =  [ affine_map<() -> (0)> ],
+    iterator_types = ["parallel"]}
+      outs(%arg0 : memref<1xi32>) {
+    ^bb(%i : i32):
+    linalg.yield %i : i32
   }
 }
 
 // -----
 
-func.func @generic_wrong_iterator(% arg0 : memref<1xi32>) {
+func.func @generic_wrong_iterator(%arg0: memref<1xi32>) {
   // expected-error @+4 {{unexpected iterator_type (random)}}
-  linalg.generic{indexing_maps = [affine_map<(i)->(i)>],
-                 iterator_types = ["random"]} outs(% arg0 : memref<1xi32>) {
-    ^bb(% i : i32) : linalg.yield % i : i32
+  linalg.generic {
+    indexing_maps =  [ affine_map<(i) -> (i)> ],
+    iterator_types = ["random"]}
+      outs(%arg0 : memref<1xi32>) {
+    ^bb(%i : i32):
+    linalg.yield %i : i32
   }
 }
 
 // -----
 
 func.func @generic_one_d_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-  // expected-error @+1 {{expected operand rank (1) to match the result rank of
-  // indexing_map #0 (2)}}
+  // expected-error @+1 {{expected operand rank (1) to match the result rank of indexing_map #0 (2)}}
   linalg.generic {
     indexing_maps =  [ affine_map<() -> (0, 0)> ],
     iterator_types = []}
       outs(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-    ^bb(% f : f32) : linalg.yield % f : f32
+    ^bb(%f : f32):
+      linalg.yield %f: f32
   }
 }
 
@@ -113,20 +129,22 @@ func.func @generic_scalar_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off +
     iterator_types = []}
       ins(%cst : f32)
       outs(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-    ^bb(% 0 : f32, % 1 : f32) : linalg.yield % 0 : f32
+    ^bb(%0 : f32, %1 : f32):
+      linalg.yield %0: f32
   }
 }
 
 // -----
 
 func.func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-  // expected-error @+7 {{'linalg.yield' op type of yield operand 1 ('i4')
-  // doesn't match the element type of the enclosing linalg.generic op ('f32')}}
+  // expected-error @+7 {{'linalg.yield' op type of yield operand 1 ('i4') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
   linalg.generic {
     indexing_maps =  [ affine_map<(i) -> (i)> ],
     iterator_types = ["parallel"]}
       outs(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-    ^bb(% 0 : f32) : % 1 = arith.constant 1 : i4 linalg.yield % 1 : i4
+    ^bb(%0: f32):
+      %1 = arith.constant 1: i4
+      linalg.yield %1: i4
   }
 }
 
@@ -142,7 +160,8 @@ func.func @generic_singular_maps(%arg0: memref<?xf32, affine_map<(i)[off]->(off
     iterator_types = ["parallel","parallel"]}
     ins(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>)
    outs(%arg1 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-    ^bb(% 0 : f32, % 1 : f32) : linalg.yield % 1 : f32
+  ^bb(%0: f32, %1: f32):
+      linalg.yield %1: f32
   }
 }
 
@@ -152,53 +171,57 @@ func.func @generic_singular_maps(%arg0: memref<?xf32, affine_map<(i)[off]->(off
 
 // -----
 
-func.func @generic_empty_region(% arg0 : memref<f32>) {
-  % f0 = arith
-             .constant 0.0
-      : f32
-            // expected-error @+1 {{op expects region #0 to have 0 or 1 blocks}}
-            linalg.generic{indexing_maps =
-                               [ affine_map<()->()>, affine_map<()->()> ],
-                           iterator_types = []} ins(% arg0 : memref<f32>)
-                outs(% arg0 : memref<f32>) {
-    ^bb1 : linalg.yield % f0 : f32 ^ bb2 : linalg.yield % f0 : f32
+func.func @generic_empty_region(%arg0: memref<f32>) {
+  %f0 = arith.constant 0.0: f32
+  // expected-error @+1 {{op expects region #0 to have 0 or 1 blocks}}
+  linalg.generic {
+    indexing_maps =  [ affine_map<() -> ()>, affine_map<() -> ()> ],
+    iterator_types = []}
+      ins(%arg0 : memref<f32>)
+     outs(%arg0 : memref<f32>) {
+    ^bb1:
+      linalg.yield %f0: f32
+    ^bb2:
+      linalg.yield %f0: f32
   }
 }
 
 // -----
 
-func.func @generic_empty_region(% arg0 : memref<f32>) {
-  % f0 = arith
-             .constant 0.0
-      : f32
-            // expected-error @+1 {{op expects to have 1 region with 1 block}}
-            linalg.generic{indexing_maps =
-                               [ affine_map<()->()>, affine_map<()->()> ],
-                           iterator_types = []} ins(% arg0 : memref<f32>)
-                outs(% arg0 : memref<f32>) {}
+func.func @generic_empty_region(%arg0: memref<f32>) {
+  %f0 = arith.constant 0.0: f32
+  // expected-error @+1 {{op expects to have 1 region with 1 block}}
+  linalg.generic {
+    indexing_maps =  [ affine_map<() -> ()> , affine_map<() -> ()> ],
+    iterator_types = []}
+    ins(%arg0 : memref<f32>)
+   outs(%arg0 : memref<f32>) {
+  }
 }
 
 // -----
 
-func.func @generic_mismatched_num_arguments(% arg0 : memref<f32>) {
-  // expected-error @+6 {{'linalg.yield' op expected number of yield values (1)
-  // to match the number of inits / outs operands of the enclosing LinalgOp
-  // (2)}}
-  linalg.generic{indexing_maps = [ affine_map<()->()>, affine_map<()->()> ],
-                 iterator_types = []} outs(% arg0, % arg0 : memref<f32>,
-                                           memref<f32>) {
-    ^bb(% f : f32) : linalg.yield % f : f32
+func.func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
+  // expected-error @+6 {{'linalg.yield' op expected number of yield values (1) to match the number of inits / outs operands of the enclosing LinalgOp (2)}}
+  linalg.generic {
+      indexing_maps =  [ affine_map<() -> ()>, affine_map<() -> ()> ],
+      iterator_types = []}
+      outs(%arg0, %arg0 : memref<f32>, memref<f32>) {
+    ^bb(%f: f32):
+      linalg.yield %f: f32
   }
 }
 
 // -----
 
-func.func @generic_shaped_operand_block_arg_type(% arg0 : memref<f32>) {
-  // expected-error @+6 {{'linalg.yield' op type of yield operand 1 ('i1')
-  // doesn't match the element type of the enclosing linalg.generic op ('f32')}}
-  linalg.generic{indexing_maps = [affine_map<()->()>],
-                 iterator_types = []} outs(% arg0 : memref<f32>) {
-    ^bb(% i : i1) : linalg.yield % i : i1
+func.func @generic_shaped_operand_block_arg_type(%arg0: memref<f32>) {
+  // expected-error @+6 {{'linalg.yield' op type of yield operand 1 ('i1') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
+  linalg.generic {
+    indexing_maps =  [ affine_map<() -> ()> ],
+    iterator_types = []}
+      outs(%arg0 : memref<f32>) {
+    ^bb(%i: i1):
+    linalg.yield %i : i1
   }
 }
 
@@ -218,13 +241,14 @@ func.func @generic_scalar_operand_block_arg_type(%arg0: tensor<f32>) {
 // -----
 
 func.func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-  // expected-error @+7 {{type of yield operand 1 ('i1') doesn't match the
-  // element type of the enclosing linalg.generic op ('f32')}}
+  // expected-error @+7 {{type of yield operand 1 ('i1') doesn't match the element type of the enclosing linalg.generic op ('f32')}}
   linalg.generic {
     indexing_maps = [ affine_map<(i) -> (i)> ],
     iterator_types = ["parallel"]}
       outs(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-    ^bb(% i : f32) : % 0 = arith.constant 0 : i1 linalg.yield % 0 : i1
+    ^bb(%i: f32):
+      %0 = arith.constant 0: i1
+      linalg.yield %0: i1
   }
 }
 
@@ -640,96 +664,83 @@ func.func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<
         iterator_types = ["parallel"]
 }
 
-func.func @invalid_reverse(% A : memref<5xf32>, % B : memref<5xf32>){
-    // expected-error @+1 {{unexpected result less than 0 at expression #0 in}}
-    linalg.generic #attrs ins(% A : memref<5xf32>) outs(% B : memref<5xf32>){
-      ^bb0(% a : f32, % b : f32) : linalg.yield % a : f32
-    } return }
-
-// -----
-
-func.func @map_binary_wrong_yield_operands(% lhs : tensor<64xf32>,
-                                           % rhs : tensor<64xf32>,
-                                           % init : tensor<64xf32>)
-    ->tensor<64xf32>{
-      % add =
-          linalg
-              .map ins(% lhs, % rhs : tensor<64xf32>, tensor<64xf32>) outs(
-                  % init : tensor<64xf32>)(% lhs_elem : f32, % rhs_elem : f32){
-                % 0 = arith.addf % lhs_elem,
-                %
-                    rhs_elem : f32
-                        // expected-error @+1{{'linalg.yield' op expected number
-                        // of yield values (2) to match the number of inits /
-                        // outs operands of the enclosing LinalgOp (1)}}
-                        linalg.yield %
-                    0,
-                % 0 : f32,
-                f32
-              } func.return %
-          add : tensor<64xf32>
-    }
-
-// -----
-
-func.func @map_input_mapper_arity_mismatch(% lhs : tensor<64xf32>,
-                                           % rhs : tensor<64xf32>,
-                                           % init : tensor<64xf32>)
-    ->tensor<64xf32>{
-      // expected-error at +1{{'linalg.map' op expects number of operands to match
-      // the arity of mapper, but got: 2 and 3}}
-      % add = linalg
-                  .map ins(% lhs, % rhs : tensor<64xf32>, tensor<64xf32>)
-                      outs(% init : tensor<64xf32>)(% lhs_elem : f32,
-                                                    % rhs_elem : f32,
-                                                    % extra_elem : f32){
-                        % 0 = arith.addf % lhs_elem,
-                        % rhs_elem : f32 linalg.yield % 0 : f32
-                      } func.return %
-              add : tensor<64xf32>
-    }
-
-// -----
-
-func.func @map_input_mapper_type_mismatch(% lhs : tensor<64xf32>,
-                                          % rhs : tensor<64xf32>,
-                                          % init : tensor<64xf32>)
-    ->tensor<64xf32>{
-      // expected-error at +1{{'linalg.map' op expected element type of input 'f32'
-      // to match bbArg type 'f64'}}
-      % add = linalg
-                  .map ins(% lhs, % rhs : tensor<64xf32>, tensor<64xf32>)
-                      outs(% init : tensor<64xf32>)(% lhs_elem : f64,
-                                                    % rhs_elem : f64){
-                        % 0 = arith.addf % lhs_elem,
-                        % rhs_elem : f64 linalg.yield % 0 : f64
-                      } func.return %
-              add : tensor<64xf32>
-    }
-
-// -----
-
-func.func @map_input_output_shape_mismatch(% lhs : tensor<64x64xf32>,
-                                           % rhs : tensor<64x64xf32>,
-                                           % init : tensor<32xf32>)
-    ->tensor<32xf32>{
-      // expected-error at +1{{'linalg.map' op expected shape of input (64, 64) to
-      // match shape of output (32)}}
-      % add = linalg
-                  .map ins(% lhs, % rhs : tensor<64x64xf32>, tensor<64x64xf32>)
-                      outs(% init : tensor<32xf32>)(% lhs_elem : f32,
-                                                    % rhs_elem : f32){
-                        % 0 = arith.addf % lhs_elem,
-                        % rhs_elem : f32 linalg.yield % 0 : f32
-                      } func.return %
-              add : tensor<32xf32>
-    }
+func.func @invalid_reverse(%A: memref<5xf32>, %B: memref<5xf32>) {
+  // expected-error @+1 {{unexpected result less than 0 at expression #0 in}}
+  linalg.generic #attrs ins(%A: memref<5xf32>) outs(%B: memref<5xf32>) {
+                ^bb0(%a: f32, %b: f32):
+                linalg.yield %a : f32
+        }
+        return
+}
+
+// -----
+
+func.func @map_binary_wrong_yield_operands(
+    %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
+    -> tensor<64xf32> {
+   %add = linalg.map
+          ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+          outs(%init:tensor<64xf32>)
+          (%lhs_elem: f32, %rhs_elem: f32) {
+            %0 = arith.addf %lhs_elem, %rhs_elem: f32
+            // expected-error @+1{{'linalg.yield' op expected number of yield values (2) to match the number of inits / outs operands of the enclosing LinalgOp (1)}}
+            linalg.yield %0, %0: f32, f32
+          }
+  func.return %add : tensor<64xf32>
+}
+
+// -----
+
+func.func @map_input_mapper_arity_mismatch(
+    %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
+    -> tensor<64xf32> {
+  // expected-error at +1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 2 and 3}}
+  %add = linalg.map
+      ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+      outs(%init:tensor<64xf32>)
+      (%lhs_elem: f32, %rhs_elem: f32, %extra_elem: f32) {
+        %0 = arith.addf %lhs_elem, %rhs_elem: f32
+        linalg.yield %0: f32
+      }
+  func.return %add : tensor<64xf32>
+}
+
+// -----
+
+func.func @map_input_mapper_type_mismatch(
+    %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
+    -> tensor<64xf32> {
+    // expected-error at +1{{'linalg.map' op expected element type of input 'f32' to match bbArg type 'f64'}}
+  %add = linalg.map
+      ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+      outs(%init:tensor<64xf32>)
+      (%lhs_elem: f64, %rhs_elem: f64) {
+        %0 = arith.addf %lhs_elem, %rhs_elem: f64
+        linalg.yield %0: f64
+      }
+  func.return %add : tensor<64xf32>
+}
+
+// -----
+
+func.func @map_input_output_shape_mismatch(
+    %lhs: tensor<64x64xf32>, %rhs: tensor<64x64xf32>, %init: tensor<32xf32>)
+    -> tensor<32xf32> {
+    // expected-error at +1{{'linalg.map' op expected shape of input (64, 64) to match shape of output (32)}}
+  %add = linalg.map
+      ins(%lhs, %rhs : tensor<64x64xf32>, tensor<64x64xf32>)
+      outs(%init:tensor<32xf32>)
+      (%lhs_elem: f32, %rhs_elem: f32) {
+        %0 = arith.addf %lhs_elem, %rhs_elem: f32
+        linalg.yield %0: f32
+      }
+  func.return %add : tensor<32xf32>
+}
 
 // -----
 
 func.func @map_no_operands1() {
-  // expected-error @+1 {{'linalg.map' op expected 1 or more operands, but found
-  // 0}}
+  // expected-error @+1 {{'linalg.map' op expected 1 or more operands, but found 0}}
   linalg.map { arith.addf }
 }
 
@@ -1665,6 +1676,7 @@ func.func @pack_source_dest_type_mismatch_1(%source: tensor<128x256xf32>, %dest:
   return
 }
 
+
 // -----
 
 func.func @pack_source_dest_type_mismatch_2(%source: memref<128x256xf32>, %dest: tensor<8x16x8x32xf32>) {
@@ -1687,7 +1699,7 @@ func.func @unpack_source_dest_type_mismatch_1(%source: tensor<16x8x8x32xf32>, %d
 
 func.func @unpack_source_dest_type_mismatch_1(%source: memref<16x8x8x32xf32>, %dest: tensor<128x256xf32>) {
   // expected-error at +1 {{mixing tensor and buffer semantics is not allowed}}
-  % 0 = linalg.unpack % source inner_dims_pos = [ 0, 1 ] inner_tiles =
-            [ 8, 32 ] into %
-            dest : memref<16x8x8x32xf32>->tensor<128x256xf32> return
-}
+  %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+      into %dest : memref<16x8x8x32xf32> -> tensor<128x256xf32>
+  return
+}
\ No newline at end of file

>From 4557fdeda980fabed1b0da75b52339ca1b4a93c4 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sun, 13 Apr 2025 13:43:41 +0900
Subject: [PATCH 32/33] revert

---
 mlir/test/Dialect/Linalg/canonicalize.mlir | 235 +++++----------------
 1 file changed, 55 insertions(+), 180 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index eafbb99caecaa..8ad008d8bbebd 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1,42 +1,30 @@
-// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file |
-// FileCheck %s
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @memref_cast(
-func.func @memref_cast(% a : index, % b : index)->memref < ? x ? xf32>{
-  % c0 = arith.constant 0 : index %
-         c1 = arith.constant 1 : index %
-              c8 = arith.constant 8 : index %
-                   c16 = arith.constant 16 : index %
-                         1 = memref.alloc(% b) : memref <
-                                 ? xi8 > % 2 = memref.view % 1 [% c0][]
-                             : memref < ? xi8 > to memref<16x16xf32> %
-                                                    3 = memref.cast % 2
-                                        : memref<16x16xf32> to memref <
-                                 ? x
-                                 ? xf32 >
-
-                                       // CHECK:  linalg.matmul
-                                       // ins({{.*}}memref<16x16xf32>,
-                                       // memref<16x16xf32>)
-                                       // outs({{.*}}memref<16x16xf32>)
-                                       linalg.matmul ins(
-                                           % 3, % 3 : memref < ? x ? xf32 >,
-                                           memref < ? x
-                                           ? xf32 >)
-                                                 outs(% 3 : memref <
-                                                      ? x ? xf32 >) return % 3
-                                                          : memref <
-                                                      ? x
-                                                      ? xf32 > }
+func.func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c8 = arith.constant 8 : index
+  %c16 = arith.constant 16 : index
+  %1 = memref.alloc (%b) : memref<?xi8>
+  %2 = memref.view %1[%c0][] : memref<?xi8> to memref<16x16xf32>
+  %3 = memref.cast %2 : memref<16x16xf32> to memref<?x?xf32>
+
+  // CHECK:  linalg.matmul ins({{.*}}memref<16x16xf32>, memref<16x16xf32>) outs({{.*}}memref<16x16xf32>)
+  linalg.matmul ins(%3, %3: memref<?x?xf32>, memref<?x?xf32>)
+               outs(%3: memref<?x?xf32>)
+  return %3: memref<?x?xf32>
+}
 
 // -----
 
 #accesses = [
-                                                        affine_map<(i)->(i)>]
+  affine_map<(i) -> (i)>
+]
 
 #trait = {
-                                                      indexing_maps = #accesses,
-                                           iterator_types = ["parallel"]
+  indexing_maps = #accesses,
+  iterator_types = ["parallel"]
 }
 
 func.func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
@@ -129,7 +117,7 @@ func.func @linalg_effects(
 
 // -----
 
-#map = affine_map < (d0, d1, d2)->(d0, d1, d2)>
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func.func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
   -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
   %c0 = arith.constant 0 : index
@@ -156,7 +144,7 @@ func.func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
 
 // -----
 
-#map = affine_map < (d0, d1, d2)->(d0, d1, d2)>
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func.func @remove_no_op_mismatched_types(%arg0 : tensor<?x?x?xf32>)
   -> tensor<1x2x3xf32> {
   %out = tensor.empty() : tensor<1x2x3xf32>
@@ -172,12 +160,12 @@ func.func @remove_no_op_mismatched_types(%arg0 : tensor<?x?x?xf32>)
 }
 // CHECK-LABEL: func @remove_no_op_mismatched_types
 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-//       CHECK:     %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to
-//       tensor<1x2x3xf32> CHECK:     return %[[CAST]]
+//       CHECK:     %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<1x2x3xf32>
+//       CHECK:     return %[[CAST]]
 
 // -----
 
-#map = affine_map < ()->()>
+#map = affine_map<() -> ()>
 func.func @cant_fold_to_tensor_cast(%arg0 : f32) -> tensor<f32> {
   %out = tensor.empty() : tensor<f32>
   %g = linalg.generic {
@@ -195,7 +183,7 @@ func.func @cant_fold_to_tensor_cast(%arg0 : f32) -> tensor<f32> {
 
 // -----
 
-#map = affine_map < (d0, d1)->(d0, d1)>
+#map = affine_map<(d0, d1) -> (d0, d1)>
 func.func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -220,7 +208,7 @@ func.func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
 
 // -----
 
-#map = affine_map < (d0, d1)->(d0, d1)>
+#map = affine_map<(d0, d1) -> (d0, d1)>
 func.func @keep_not_noop(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>)
   -> (tensor<?x?xf32>, tensor<?x?xf32>) {
   %c0 = arith.constant 0 : index
@@ -398,7 +386,7 @@ func.func @fill_pack_general() -> tensor<1x1x8x4x4x8xi32>{
 
 // -----
 
-#map = affine_map < ()[s0]->(s0 ceildiv 16)>
+#map = affine_map<()[s0] -> (s0 ceildiv 16)>
 func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
@@ -507,15 +495,11 @@ func.func @no_fold_pad_fill_value_mismatch() -> tensor<412x276xf32> {
 
 // -----
 
-// Tests below verify whether static information is propagated through all the
-// operands of generic op.
-// 1. If one of the inputs of generic op has static info and it has no cast
-// source.
-// 2. If one of the inputs of generic op has static info and it is coming from
-// tensr.cast operation.
-// 3. If one of the outputs of generic op has static info and it is coming from
-// tenso.cast operation.
-#map = affine_map < (d0, d1, d2)->(d0, d1, d2)>
+// Tests below verify whether static information is propagated through all the operands of the generic op.
+// 1. If one of the inputs of the generic op has static info and it has no cast source.
+// 2. If one of the inputs of the generic op has static info and it is coming from a tensor.cast operation.
+// 3. If one of the outputs of the generic op has static info and it is coming from a tensor.cast operation.
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-LABEL: func @static_input_without_cast
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
 func.func @static_input_without_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
@@ -545,7 +529,7 @@ func.func @static_input_without_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x
 
 // -----
 
-#map = affine_map < (d0, d1, d2)->(d0, d1, d2)>
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-LABEL: func @static_input_with_cast
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
 func.func @static_input_with_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
@@ -576,7 +560,7 @@ func.func @static_input_with_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?
 
 // -----
 
-#map = affine_map < (d0, d1, d2)->(d0, d1, d2)>
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-LABEL: func @static_output_with_cast
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>, %[[ARG2:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
 func.func @static_output_with_cast(%arg0 : tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
@@ -608,9 +592,9 @@ func.func @static_output_with_cast(%arg0 : tensor<?x?x?xf32>, %arg1: tensor<?x?x
 
 // -----
 
-// This test checks the folding of tensor.cast operation when the source value
-// of cast has more static information than the destination value.
-#map = affine_map < (d0, d1, d2)->(d0, d1, d2)>
+// This test checks the folding of tensor.cast operation when the source value of cast
+// has more static information than the destination value.
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-LABEL: func @cast_source
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
 func.func @cast_source(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
@@ -641,7 +625,7 @@ func.func @cast_source(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> t
 
 // -----
 
-#map = affine_map < (d0, d1, d2)->(d0, d1, d2)>
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-LABEL: func @cast_dest
 // CHECK-SAME:  (%[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<1x?x?xf32>,
 func.func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2: index, %arg3: index, %arg4: index) -> tensor<?x?x?xf32> {
@@ -665,34 +649,6 @@ func.func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2:
 
 // -----
 
-#map = affine_map < (d0, d1)->(d0, d1)>
-#sparse = #sparse_tensor.encoding <                                            \
-          {map = (d0, d1)->(d0 : dense, d1 : compressed) }>
-// CHECK-DAG:   #[[$SPARSE:.+]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
-// CHECK-LABEL: func @static_shape_inference_with_encoding(
-// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
-func.func @static_shape_inference_with_encoding(%arg0: tensor<?x?xf32, #sparse>, %arg1: tensor<?x?xf32>) -> tensor<3x4xf32> {
-  %0 = tensor.empty() : tensor<3x4xf32>
-  %1 = linalg.generic {
-    indexing_maps = [#map, #map, #map],
-    iterator_types = ["parallel", "parallel"]
-  } ins(%arg0, %arg1 : tensor<?x?xf32, #sparse>, tensor<?x?xf32>)
-    outs(%0 : tensor<3x4xf32>) {
-  ^bb0(%in: f32, %in_0: f32, %out: f32):
-    %2 = arith.addf %in, %in_0 : f32
-    linalg.yield %2 : f32
-  } -> tensor<3x4xf32>
-  return %1 : tensor<3x4xf32>
-    //  CHECK:      %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?x?xf32, #[[$SPARSE]]> to tensor<3x4xf32, #[[$SPARSE]]>
-    //  CHECK-NEXT: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?xf32> to tensor<3x4xf32>
-    //  CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
-    //  CHECK-SAME: ins(%[[CAST_ARG0]], %[[CAST_ARG1]] : tensor<3x4xf32, #[[$SPARSE]]>, tensor<3x4xf32>)
-    //  CHECK-SAME: outs({{.*}} : tensor<3x4xf32>)
-}
-
-// -----
-
 //       CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 1)>
 // CHECK-LABEL: func @insert_pad_into_fill
 //  CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x?xf32>, %[[LOW0:.+]]: index, %[[LOW1:.+]]: index, %{{.+}}: index, %{{.+}}: index)
@@ -856,25 +812,23 @@ func.func @linalgop_with_cond_cast_consumer(%arg0 : tensor<?x?xf32>, %arg1 : ten
   %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
   scf.if %arg3 {
-    % 1 = tensor.cast % 0 : tensor < ? x ? xf32 > to tensor<4x8xf32> func.call
-                                                  @some_use(% 1)
-                                         : (tensor<4x8xf32>)->()
+    %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
+    func.call @some_use(%1) : (tensor<4x8xf32>) -> ()
   }
-  return % 0 : tensor < ? x ? xf32 >
+  return %0 : tensor<?x?xf32>
 }
 
 // Check conditionally reachable cast is not folded into producer.
 // CHECK-LABEL: func @linalgop_with_cond_cast_consumer
-//  CHECK-SAME:     (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]:
-//  tensor<?x?xf32>, %[[ARG2:.*]]: tensor<?x?xf32>, %[[ARG3:.*]]: i1)
-//       CHECK: %[[RES:.*]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] :
-//       tensor<?x?xf32>, tensor<?x?xf32>)
+//  CHECK-SAME:     (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>, %[[ARG2:.*]]: tensor<?x?xf32>, %[[ARG3:.*]]: i1)
+//       CHECK: %[[RES:.*]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
 //  CHECK-SAME:      outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 //       CHECK: scf.if %[[ARG3]] {
-//       CHECK:   %[[CAST:.*]] = tensor.cast %[[RES]] : tensor<?x?xf32> to
-//       tensor<4x8xf32> CHECK:   func.call @some_use(%[[CAST]]) :
-//       (tensor<4x8xf32>) -> () CHECK: } CHECK: return %[[RES]] :
-//       tensor<?x?xf32>
+//       CHECK:   %[[CAST:.*]] = tensor.cast %[[RES]] : tensor<?x?xf32> to tensor<4x8xf32>
+//       CHECK:   func.call @some_use(%[[CAST]]) : (tensor<4x8xf32>) -> ()
+//       CHECK: }
+//       CHECK: return %[[RES]] : tensor<?x?xf32>
+
 
 // -----
 
@@ -923,19 +877,17 @@ func.func @fold_multi_use_generic_op_with_consumer(%arg0 : tensor<?x?x?xf32>) ->
 //       CHECK: func @fold_multi_use_generic_op_with_consumer
 //  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
 //   CHECK-DAG:   %[[INIT1:.+]] = tensor.empty() : tensor<2x3x4xf32>
-//   CHECK-DAG:   %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to
-//   tensor<4x3x2xf32> CHECK-DAG:   %[[INIT2:.+]] = tensor.empty() :
-//   tensor<3x2x4xf32>
+//   CHECK-DAG:   %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<4x3x2xf32>
+//   CHECK-DAG:   %[[INIT2:.+]] = tensor.empty() : tensor<3x2x4xf32>
 //       CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
 //  CHECK-SAME:       ins(%[[CAST]] :
 //  CHECK-SAME:       outs(%[[INIT2]], %[[INIT1]] :
-//       CHECK:   %[[RETURN_CAST:.+]] = tensor.cast %[[GENERIC]]#0 :
-//       tensor<3x2x4xf32> to tensor<?x?x?xf32> CHECK:   return
-//       %[[RETURN_CAST]], %[[GENERIC]]#1
+//       CHECK:   %[[RETURN_CAST:.+]] = tensor.cast %[[GENERIC]]#0 : tensor<3x2x4xf32> to tensor<?x?x?xf32>
+//       CHECK:   return %[[RETURN_CAST]], %[[GENERIC]]#1
 
 // -----
 
-#map = affine_map < (d0)->(d0)>
+#map = affine_map<(d0) -> (d0)>
 func.func @identity_buffer(%arg0 : memref<?xf32>, %arg1: memref<?xf32>) {
   linalg.generic {
     indexing_maps = [#map, #map],
@@ -959,7 +911,7 @@ func.func @identity_buffer(%arg0 : memref<?xf32>, %arg1: memref<?xf32>) {
 
 // -----
 
-#map = affine_map < (d0, d1)->(d1, d0)>
+#map = affine_map<(d0, d1) -> (d1, d0)>
 func.func @erase_non_identity_noop(%arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.generic {
     indexing_maps = [#map, #map],
@@ -1807,81 +1759,4 @@ func.func @fold_cast_unpack_dynamic_tile_size(
       inner_tiles = [%c8, 1]
       into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
     return %unpack : tensor<7x?xi32>
-}
-
-// -----
-
-//===----------------------------------------------------------------------===//
-// linalg.unpack + tensor.extract_slice
-//===----------------------------------------------------------------------===//
-
-func.func @fold_extract_slice_into_unpack(
-    %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
-) -> tensor<28x28x?xf32> {
-  %unpack = linalg.unpack %src
-      outer_dims_perm = [0, 1, 2]
-      inner_dims_pos = [1, 2]
-      inner_tiles = [16, 16]
-      into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
-  %extracted_slice = tensor.extract_slice %unpack
-      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
-  return %extracted_slice : tensor<28x28x?xf32>
-}
-
-// CHECK-LABEL: func @fold_extract_slice_into_unpack
-//  CHECK-SAME:     %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
-//  CHECK-SAME:     %[[DEST:.+]]: tensor<28x32x?xf32>
-//  CHECK-SAME:     %[[SIZE:.+]]: index
-//       CHECK:   %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
-//  CHECK-SAME:     [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
-//       CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
-//  CHECK-SAME:       into %[[DEST_SLICE]]
-//       CHECK:   return %[[UNPACK]]
-
-// -----
-
-func.func @no_fold_extract_slice_into_unpack_rank_reducing(
-    %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
-) -> tensor<28xf32> {
-  %unpack = linalg.unpack %src
-      outer_dims_perm = [0, 1]
-      inner_dims_pos = [1]
-      inner_tiles = [16]
-      into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
-  %extracted_slice = tensor.extract_slice %unpack
-      [0, 0] [1, 28] [1, 1] : tensor<28x32xf32> to tensor<28xf32>
-  return %extracted_slice : tensor<28xf32>
-}
-
-// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_rank_reducing
-//  CHECK-SAME:     %[[SRC:.+]]: tensor<28x2x16xf32>
-//  CHECK-SAME:     %[[DEST:.+]]: tensor<28x32xf32>
-//       CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
-//  CHECK-SAME:       into %[[DEST]]
-//       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
-//       CHECK:   return %[[SLICE]]
-
-// -----
-
-func.func @no_fold_extract_slice_into_unpack_non_zero_offset(
-    %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
-) -> tensor<28x28xf32> {
-  % unpack =
-      linalg.unpack % src outer_dims_perm =
-          [ 0, 1 ] inner_dims_pos = [1] inner_tiles =
-              [16] into % dest : tensor<28x2x16xf32>->tensor<28x32xf32> %
-              extracted_slice =
-                  tensor.extract_slice %
-                  unpack[0, 1][28, 28][1, 1] : tensor<28x32xf32> to
-                                                   tensor<28x28xf32> return %
-                                               extracted_slice
-      : tensor<28x28xf32>
-}
-
-// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_non_zero_offset
-//  CHECK-SAME:     %[[SRC:.+]]: tensor<28x2x16xf32>
-//  CHECK-SAME:     %[[DEST:.+]]: tensor<28x32xf32>
-//       CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
-//  CHECK-SAME:       into %[[DEST]]
-//       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
-//       CHECK:   return %[[SLICE]]
+}
\ No newline at end of file

>From 7d82d43aa628b92b8badffd72fe3d458c3b299a4 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Wed, 16 Apr 2025 05:47:53 +0900
Subject: [PATCH 33/33] [mlir][linalg] Replace getPackOpResultTypeShape helper with PackOp::inferPackedShape

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 30 ++++++++----------------
 1 file changed, 10 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 82eb513ff940c..262e8e55c28cc 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4705,13 +4705,11 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
   return result;
 }
 
-/// Helper for PackOp::{getResultShape,inferPackedTensorType}. Returns the shape
-/// of the packed type. Having a shared helper helps implement these two methods
-/// in a way that ensures that they agree on which dimensions are dynamic.
-static SmallVector<int64_t> getPackOpResultTypeShape(
-    ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
-    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
-  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
+SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
+                                              ArrayRef<int64_t> innerTileSizes,
+                                              ArrayRef<int64_t> innerDimsPos,
+                                              ArrayRef<int64_t> outerDimsPerm) {
+  SmallVector<int64_t> resultShape = llvm::to_vector(inputShape);
   for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
     if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
       continue;
@@ -4751,9 +4749,9 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
   resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
 
   SmallVector<int64_t> resultTypeShape =
-      getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
-                               asShapeWithAnyValueAsDynamic(innerTileSizes),
-                               innerDimsPos, outerDimsPerm);
+      inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims),
+                       asShapeWithAnyValueAsDynamic(innerTileSizes),
+                       innerDimsPos, outerDimsPerm);
 
   // Fix-up `resultDims` to ensure that they are Value's if and only if the
   // result type shape says it's a dynamic dim. This is needed as callers may
@@ -4774,7 +4772,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
 RankedTensorType PackOp::inferPackedTensorType(
     RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
     ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
-  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
+  SmallVector<int64_t> resultShape = inferPackedShape(
       sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
   return RankedTensorType::get(resultShape, sourceType.getElementType());
 }
@@ -4783,19 +4781,11 @@ MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
                                          ArrayRef<int64_t> innerTileSizes,
                                          ArrayRef<int64_t> innerDimsPos,
                                          ArrayRef<int64_t> outerDimsPerm) {
-  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
+  SmallVector<int64_t> resultShape = inferPackedShape(
       sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
   return MemRefType::get(resultShape, sourceType.getElementType());
 }
 
-SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
-                                              ArrayRef<int64_t> innerTileSizes,
-                                              ArrayRef<int64_t> innerDimsPos,
-                                              ArrayRef<int64_t> outerDimsPerm) {
-  return getPackOpResultTypeShape(inputShape, innerTileSizes, innerDimsPos,
-                                  outerDimsPerm);
-}
-
 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
                                       ArrayRef<OpFoldResult> innerTileSizes,
                                       ArrayRef<int64_t> innerDimsPos,



More information about the Mlir-commits mailing list