[Mlir-commits] [mlir] 4b066c7 - [mlir][linalg] Extend linalg.pack and linalg.unpack to accept memref (#167675)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 19 07:42:34 PST 2026


Author: Ryutaro Okada
Date: 2026-01-19T16:42:27+01:00
New Revision: 4b066c7fff3455dc547fabb676583391febe41e9

URL: https://github.com/llvm/llvm-project/commit/4b066c7fff3455dc547fabb676583391febe41e9
DIFF: https://github.com/llvm/llvm-project/commit/4b066c7fff3455dc547fabb676583391febe41e9.diff

LOG: [mlir][linalg] Extend linalg.pack and linalg.unpack to accept memref (#167675)

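Extend linalg.pack and linalg.unpack to accept memref operands in
addition to tensors. With memref operands the ops have buffer semantics
and produce no result; mixing tensor and memref operands is rejected.
As part of this change, we now disable all transformations when these
ops have memref semantics.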

Closes https://github.com/llvm/llvm-project/issues/129004

---------

Signed-off-by: Ryutaro Okada <1015ryu88 at gmail.com>
Co-authored-by: Hyunsung Lee <ita9naiwa at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
    mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
    mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
    mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/python/mlir/dialects/linalg/__init__.py
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/python/dialects/linalg/ops.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index f80c302ba7c51..95383e6262f71 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -7,11 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file defines Pack + Unpack Ops that have been moved from the Tensor
-// dialect. As such, these are defined as memory-effect-free and only accept
-// "tensors" as inputs.
-//
-// TODO: Once a good motivating example is identified, relax these
-// restrictions.
+// dialect.
 //
 //===----------------------------------------------------------------------===//
 
@@ -30,24 +26,27 @@ include "mlir/IR/OpAsmInterface.td"
 // RelayoutOp
 //===----------------------------------------------------------------------===//
 
-class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
-      Op<Linalg_Dialect, mnemonic, !listconcat(traits, [
-        DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-        DestinationStyleOpInterface, LinalgRelayoutOpInterface,
-        ConditionallySpeculatable, NoMemoryEffect,
-        DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []>
+    : Op<Linalg_Dialect, mnemonic,
+         !listconcat(
+             traits, [DeclareOpInterfaceMethods<
+                          OpAsmOpInterface, ["getAsmResultNames"]>,
+                      DestinationStyleOpInterface, LinalgRelayoutOpInterface,
+                      ConditionallySpeculatable,
+                      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+                      DeclareOpInterfaceMethods<
+                          ReifyRankedShapedTypeOpInterface, [
           "reifyResultShapes"]>,
-        TypesMatchWith<"result type matches type of dest",
-                   "dest", "result",
-                   "$_self">])> {
+                      OptionalTypesMatchWith<"result type matches type of dest",
+                                             "dest", "result", "$_self">])> {
 
   code commonExtraClassDeclaration = [{
     size_t getSourceRank() { return getSourceType().getRank(); };
     size_t getDestRank() { return getDestType().getRank(); };
-    RankedTensorType getSourceType() {
-      return ::llvm::cast<RankedTensorType>(getSource().getType()); };
-    RankedTensorType getDestType() {
-      return ::llvm::cast<RankedTensorType>(getDest().getType()); };
+    ShapedType getSourceType() {
+      return ::llvm::cast<ShapedType>(getSource().getType()); };
+    ShapedType getDestType() {
+      return ::llvm::cast<ShapedType>(getDest().getType()); };
 
     MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
 
@@ -195,23 +194,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     //            expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
     ```
   }];
-  let arguments = (ins AnyRankedTensor:$source,
-                       AnyRankedTensor:$dest,
-                       Optional<AnyType>:$padding_value,
-                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
-                       DenseI64ArrayAttr:$inner_dims_pos,
-                       Variadic<Index>:$inner_tiles,
-                       DenseI64ArrayAttr:$static_inner_tiles);
-  let results = (outs AnyRankedTensor:$result);
-  let assemblyFormat = [{
-    $source
-    (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
-    (`outer_dims_perm` `=` $outer_dims_perm^)?
-    `inner_dims_pos` `=` $inner_dims_pos
-    `inner_tiles` `=`
-    custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
-    `into` $dest attr-dict `:` type($source) `->` type($dest)
-  }];
+  let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
+      TensorOrMemRef<[AnyType]>:$dest, 
+      Optional<AnyType>:$padding_value,
+      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
+      DenseI64ArrayAttr:$inner_dims_pos, 
+      Variadic<Index>:$inner_tiles,
+      DenseI64ArrayAttr:$static_inner_tiles);
+  let results = (outs Optional<AnyRankedTensor>:$result);
 
   let builders = [
     OpBuilder<(ins "Value":$source, "Value":$dest,
@@ -233,7 +223,21 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     // Method to get the `RankedTensorType` of the result based on the inner
     // tiles, position of the inner tiles (innerDimsPos)  and interchange vector
     // of outer loops (outerDimsPerm).
-    static RankedTensorType inferPackedType(RankedTensorType sourceType,
+    static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
+        ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
+        ArrayRef<int64_t> outerDimsPerm = {});
+
+    // Method to get the `MemRefType` of the result based on the inner
+    // tiles, position of the inner tiles (innerDimsPos)  and interchange vector
+    // of outer loops (outerDimsPerm).
+    static MemRefType inferPackedMemRefType(MemRefType sourceType,
+        ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
+        ArrayRef<int64_t> outerDimsPerm = {});
+
+    // Returns the shape of the packed type. It is a shared helper that helps
+    // type inference methods in a way that ensures that they agree on which 
+    // dimensions are dynamic.
+    static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
         ArrayRef<int64_t> outerDimsPerm = {});
 
@@ -285,6 +289,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
   let hasCanonicalizeMethod = 1;
 
   let hasFolder = 1;
+
+  let hasCustomAssemblyFormat = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -352,21 +358,12 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
     //          Outer Dims: 9x3x8   Inner Dims: 4x2
     ```
   }];
-  let arguments = (ins AnyRankedTensor:$source,
-                       AnyRankedTensor:$dest,
-                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
-                       DenseI64ArrayAttr:$inner_dims_pos,
-                       Variadic<Index>:$inner_tiles,
-                       DenseI64ArrayAttr:$static_inner_tiles);
-  let results = (outs AnyRankedTensor:$result);
-  let assemblyFormat = [{
-    $source
-    (`outer_dims_perm` `=` $outer_dims_perm^)?
-    `inner_dims_pos` `=` $inner_dims_pos
-    `inner_tiles` `=`
-    custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
-    `into` $dest attr-dict `:` type($source) `->` type($dest)
-  }];
+  let arguments = (ins TensorOrMemRef<[AnyType]>:$source,
+      TensorOrMemRef<[AnyType]>:$dest,
+      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
+      DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
+      DenseI64ArrayAttr:$static_inner_tiles);
+  let results = (outs Optional<AnyRankedTensor>:$result);
 
   let builders = [
     OpBuilder<(ins "Value":$source, "Value":$dest,
@@ -409,6 +406,8 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
   let hasCanonicalizeMethod = 1;
 
   let hasFolder = 1;
+
+  let hasCustomAssemblyFormat = 1;
 }
 
 #endif // LINALG_RELEAYOUT_OPS

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e9f617a785d22..40cabe20d1a4b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4970,12 +4970,12 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
 template <typename OpTy, typename>
 SmallVector<int64_t>
 getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
-  RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
-                                    ? packOrUnPack.getDestType()
-                                    : packOrUnPack.getSourceType();
-  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
-                                      ? packOrUnPack.getSourceType()
-                                      : packOrUnPack.getDestType();
+  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
+                              ? packOrUnPack.getDestType()
+                              : packOrUnPack.getSourceType();
+  ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
+                                ? packOrUnPack.getSourceType()
+                                : packOrUnPack.getDestType();
   SmallVector<int64_t> result(
       packedType.getShape().take_front(unpackedType.getRank()));
   if (!packOrUnPack.getOuterDimsPerm().empty()) {
@@ -5109,15 +5109,30 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
     return llvm::any_of(tiles, isZeroInteger);
   };
 
+  // Verify that the source and destination are ranked types.
+  if (!packOrUnPack.getSourceType().hasRank() ||
+      !packOrUnPack.getDestType().hasRank())
+    return op->emitError("expected both source and destination to have rank");
+
+  // Verify that the Operation does not have mixed tensor/buffer semantics.
+  if (!packOrUnPack.hasPureBufferSemantics() &&
+      !packOrUnPack.hasPureTensorSemantics())
+    return op->emitError("mixing tensor and buffer semantics is not allowed");
+  const unsigned numResults = packOrUnPack.getNumResults();
+  if (packOrUnPack.hasPureTensorSemantics() && numResults != 1)
+    return op->emitError("expected 1 result, got ") << numResults;
+  if (packOrUnPack.hasPureBufferSemantics() && numResults != 0)
+    return op->emitError("expected 0 results, got ") << numResults;
+
   // Verify tiles. Do not allow zero tiles.
   SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
   if (hasZeros(mixedTiles))
     return op->emitError("invalid zero tile factor");
 
   // Verify inner_dims_pos and outer_dims_perm.
-  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
-                                      ? packOrUnPack.getSourceType()
-                                      : packOrUnPack.getDestType();
+  ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
+                                ? packOrUnPack.getSourceType()
+                                : packOrUnPack.getDestType();
   size_t unpackedRank = unpackedType.getRank();
   ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
   ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
@@ -5154,8 +5169,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   // Verify result shape is greater than the minimum expected
   // by the pack operation, and that the output shape
   // represents full tiles.
-  RankedTensorType expectedPackedType = PackOp::inferPackedType(
-      unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
+  SmallVector<int64_t> expectedPackedShape = PackOp::inferPackedShape(
+      unpackedType.getShape(), packOrUnPack.getStaticTiles(),
+      packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
   if (!llvm::all_of(
           llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                     mixedTiles),
@@ -5172,11 +5188,20 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
     return op->emitError("mismatch in inner tile sizes specified and shaped of "
                          "tiled dimension in the packed type");
   }
-  if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
-                                   packedType.getShape()))) {
+  if (failed(
+          verifyCompatibleShape(expectedPackedShape, packedType.getShape()))) {
+    auto elementType = unpackedType.getElementType();
+    Type expectedType, actualType;
+    if (packOrUnPack.hasPureTensorSemantics()) {
+      expectedType = RankedTensorType::get(expectedPackedShape, elementType);
+      actualType = RankedTensorType::get(packedType.getShape(), elementType);
+    } else {
+      expectedType = MemRefType::get(expectedPackedShape, elementType);
+      actualType = MemRefType::get(packedType.getShape(), elementType);
+    }
     return op->emitError("expected ")
-           << expectedPackedType << " for the packed domain value, got "
-           << packedType;
+           << expectedType << " for the packed domain value, got "
+           << actualType;
   }
   return success();
 }
@@ -5237,7 +5262,154 @@ commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
 //===----------------------------------------------------------------------===//
 
 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
-  setNameFn(getResult(), "pack");
+  if (!getResults().empty())
+    setNameFn(getResult(), "pack");
+}
+
+ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::UnresolvedOperand source, dest;
+  SmallVector<OpAsmParser::UnresolvedOperand> dynamicTiles;
+  SmallVector<OpAsmParser::UnresolvedOperand> paddingValue;
+  SmallVector<Type> paddingValueType;
+  SmallVector<int64_t> staticTiles;
+  DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
+  Type sourceType, destType, resultType;
+
+  if (parser.parseOperand(source))
+    return failure();
+
+  if (succeeded(parser.parseOptionalKeyword("padding_value"))) {
+    if (parser.parseLParen() ||
+        parser.parseOperandList(paddingValue, /*requiredOperandCount=*/1) ||
+        parser.parseColon() || parser.parseTypeList(paddingValueType) ||
+        parser.parseRParen())
+      return failure();
+  }
+
+  if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
+    if (parser.parseEqual())
+      return failure();
+
+    SmallVector<int64_t> outerDimsPermVec;
+    if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
+          int64_t value;
+          if (parser.parseInteger(value))
+            return failure();
+          outerDimsPermVec.push_back(value);
+          return success();
+        }))
+      return failure();
+    outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
+  }
+
+  if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
+    return failure();
+
+  SmallVector<int64_t> innerDimsPosVec;
+  if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
+        int64_t value;
+        if (parser.parseInteger(value))
+          return failure();
+        innerDimsPosVec.push_back(value);
+        return success();
+      }))
+    return failure();
+  innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);
+
+  if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
+    return failure();
+
+  DenseI64ArrayAttr staticTilesAttr;
+  if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
+    return failure();
+  for (auto val : staticTilesAttr.asArrayRef())
+    staticTiles.push_back(val);
+
+  if (parser.parseKeyword("into") || parser.parseOperand(dest))
+    return failure();
+
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  if (parser.parseColon() || parser.parseType(sourceType))
+    return failure();
+
+  bool hasArrow = succeeded(parser.parseOptionalArrow());
+  if (hasArrow) {
+    if (parser.parseType(destType))
+      return failure();
+  }
+
+  bool isMemRef = llvm::isa<MemRefType>(sourceType);
+  if (!hasArrow) {
+    return parser.emitError(parser.getCurrentLocation(),
+                            "pack/unpack requires '->' and destination type");
+  }
+
+  if (!isMemRef)
+    resultType = destType;
+
+  if (parser.resolveOperand(source, sourceType, result.operands) ||
+      parser.resolveOperand(dest, destType, result.operands))
+    return failure();
+
+  if (!paddingValue.empty() &&
+      parser.resolveOperands(paddingValue, paddingValueType[0],
+                             result.operands))
+    return failure();
+
+  if (!dynamicTiles.empty() &&
+      parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
+                             result.operands))
+    return failure();
+
+  result.addAttribute("static_inner_tiles",
+                      parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
+  result.addAttribute("inner_dims_pos", innerDimsPos);
+  if (outerDimsPerm)
+    result.addAttribute("outer_dims_perm", outerDimsPerm);
+
+  SmallVector<int32_t> segmentSizes = {
+      1, 1, static_cast<int32_t>(paddingValue.size()),
+      static_cast<int32_t>(dynamicTiles.size())};
+  result.addAttribute("operandSegmentSizes",
+                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
+
+  if (!isMemRef)
+    result.addTypes(resultType);
+
+  return success();
+}
+
+void PackOp::print(OpAsmPrinter &p) {
+  p << " " << getSource();
+
+  if (getPaddingValue()) {
+    p << " padding_value(" << getPaddingValue() << " : "
+      << getPaddingValue().getType() << ")";
+  }
+
+  if (!getOuterDimsPerm().empty()) {
+    p << " outer_dims_perm = [";
+    llvm::interleaveComma(getOuterDimsPerm(), p);
+    p << "]";
+  }
+
+  p << " inner_dims_pos = [";
+  llvm::interleaveComma(getInnerDimsPos(), p);
+  p << "]";
+
+  p << " inner_tiles = ";
+  printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
+
+  p << " into " << getDest();
+
+  p.printOptionalAttrDict((*this)->getAttrs(),
+                          {"static_inner_tiles", "inner_dims_pos",
+                           "outer_dims_perm", "operandSegmentSizes"});
+
+  p << " : " << getSource().getType();
+  p << " -> " << getDest().getType();
 }
 
 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
@@ -5262,6 +5434,8 @@ void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
 LogicalResult
 PackOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  if (!hasPureTensorSemantics())
+    return failure();
   return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
 }
 
@@ -5397,13 +5571,11 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
   return result;
 }
 
-/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
-/// the packed type. Having a shared helper helps implement these two methods in
-/// a way that ensures that they agree on which dimensions are dynamic.
-static SmallVector<int64_t> getPackOpResultTypeShape(
-    ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
-    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
-  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
+SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
+                                              ArrayRef<int64_t> innerTileSizes,
+                                              ArrayRef<int64_t> innerDimsPos,
+                                              ArrayRef<int64_t> outerDimsPerm) {
+  SmallVector<int64_t> resultShape = llvm::to_vector(inputShape);
   for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
     if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
       continue;
@@ -5443,9 +5615,9 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
   resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
 
   SmallVector<int64_t> resultTypeShape =
-      getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
-                               asShapeWithAnyValueAsDynamic(innerTileSizes),
-                               innerDimsPos, outerDimsPerm);
+      inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims),
+                       asShapeWithAnyValueAsDynamic(innerTileSizes),
+                       innerDimsPos, outerDimsPerm);
 
   // Fix-up `resultDims` to ensure that they are Value's if and only if the
   // result type shape says it's a dynamic dim. This is needed as callers may
@@ -5461,15 +5633,21 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
   return resultDims;
 }
 
-/// Get the expected packed type based on source type, tile factors, position of
-/// the inner tiles and permutation of the outer tiled loop.
-RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
+RankedTensorType PackOp::inferPackedTensorType(
+    RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
+    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
+  SmallVector<int64_t> resultShape = inferPackedShape(
+      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
+  return RankedTensorType::get(resultShape, sourceType.getElementType());
+}
+
+MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
                                          ArrayRef<int64_t> innerTileSizes,
                                          ArrayRef<int64_t> innerDimsPos,
                                          ArrayRef<int64_t> outerDimsPerm) {
-  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
+  SmallVector<int64_t> resultShape = inferPackedShape(
       sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
-  return RankedTensorType::get(resultShape, sourceType.getElementType());
+  return MemRefType::get(resultShape, sourceType.getElementType());
 }
 
 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
@@ -5518,6 +5696,45 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
                         getPaddingValue(), metadata.outerDimsPerm);
 }
 
+template <typename OpTy>
+static void getPackUnPackEffectsImpl(
+    OpTy op, SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+                 &effects) {
+  // No memory effects for pure tensor semantics
+  if (op.hasPureTensorSemantics())
+    return;
+
+  for (OpOperand &opOperand : op.getOperation()->getOpOperands()) {
+    if (!llvm::isa<MemRefType>(opOperand.get().getType()))
+      continue;
+
+    if (&opOperand == &op.getSourceMutable()) {
+      effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
+                           /*effectOnFullRegion=*/true,
+                           SideEffects::DefaultResource::get());
+    } else if (&opOperand == &op.getDestMutable()) {
+      effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
+                           /*effectOnFullRegion=*/true,
+                           SideEffects::DefaultResource::get());
+      effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
+                           /*effectOnFullRegion=*/true,
+                           SideEffects::DefaultResource::get());
+    }
+  }
+}
+
+void PackOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  getPackUnPackEffectsImpl(*this, effects);
+}
+
+void UnPackOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  getPackUnPackEffectsImpl(*this, effects);
+}
+
 /// Returns true if the tiles and the tiled dims are constant.
 template <typename OpTy>
 static bool areTilesAndTiledDimsAllConstant(OpTy op) {
@@ -5537,6 +5754,8 @@ static bool areTilesAndTiledDimsAllConstant(OpTy op) {
 }
 
 Speculation::Speculatability PackOp::getSpeculatability() {
+  if (!hasPureTensorSemantics())
+    return Speculation::NotSpeculatable;
   if (getPaddingValue())
     return Speculation::Speculatable;
 
@@ -5627,6 +5846,10 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
 }
 
 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
+  // TODO: Support Memref PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics())
+    return failure();
+
   // Fold an pack(unpack(x)) to x.
   if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
     if (unPackOp.getSourceType() == packOp.getDestType() &&
@@ -5657,7 +5880,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
           tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
     }
     Value dest = packOp.getDest();
-    RankedTensorType originalResultType = packOp.getDestType();
+    ShapedType originalResultType = packOp.getDestType();
     bool needUpdateDestType = (destShape != originalResultType.getShape());
     if (needUpdateDestType) {
       auto newDestType = packOp.getDestType().clone(destShape);
@@ -5672,9 +5895,9 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     // Insert a cast if needed
     if (needUpdateDestType) {
       rewriter.setInsertionPointAfter(packOp);
-      auto castOp =
-          tensor::CastOp::create(rewriter, loc, originalResultType, packOp);
-      rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+      auto castOp = tensor::CastOp::create(rewriter, loc, originalResultType,
+                                           packOp.getResult());
+      rewriter.replaceAllUsesExcept(packOp.getResult(), castOp, castOp);
     }
     return success();
   }
@@ -5683,8 +5906,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
 }
 
 template <typename PackOrUnpackOp>
-static bool isLikePadUnPad(PackOrUnpackOp packOp,
-                           RankedTensorType packedTensorType) {
+static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
   static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
                     std::is_same<PackOrUnpackOp, UnPackOp>::value,
                 "Function meant for pack/unpack");
@@ -5717,19 +5939,25 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp,
 
 bool PackOp::isLikePad() {
   auto packedTensorType =
-      llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
+      llvm::cast<ShapedType>((*this)->getResultTypes().front());
   return isLikePadUnPad(*this, packedTensorType);
 }
 
-OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
+::mlir::LogicalResult
+PackOp::fold(FoldAdaptor adaptor,
+             ::llvm::SmallVectorImpl<OpFoldResult> &results) {
+  if (!hasPureTensorSemantics())
+    return failure();
   std::optional<Attribute> paddingValue;
   if (auto pad = adaptor.getPaddingValue())
     paddingValue = pad;
   if (OpFoldResult reshapedSource = reshapeConstantSource(
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
-          getDestType(), paddingValue))
-    return reshapedSource;
-  return {};
+          cast<TensorType>(getDestType()), paddingValue)) {
+    results.push_back(reshapedSource);
+    return success();
+  }
+  return failure();
 }
 
 /// Folds a tensor.cast op into a consuming PackOp op if the
@@ -5751,6 +5979,10 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
 
   LogicalResult matchAndRewrite(PackOp op,
                                 PatternRewriter &rewriter) const override {
+    // TODO: Support Memref PackOp. Temporarily return failure.
+    if (!op.hasPureTensorSemantics())
+      return failure();
+
     if (!tensor::hasFoldableTensorCastOperand(op))
       return failure();
 
@@ -5793,12 +6025,141 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
 
 void UnPackOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
-  setNameFn(getResult(), "unpack");
+  if (!getResults().empty())
+    setNameFn(getResult(), "unpack");
+}
+
+// Custom parser for UnPackOp that handles the memref/tensor case distinction
+ParseResult UnPackOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::UnresolvedOperand source, dest;
+  SmallVector<OpAsmParser::UnresolvedOperand> dynamicTiles;
+  SmallVector<int64_t> staticTiles;
+  DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
+  Type sourceType, destType, resultType;
+
+  if (parser.parseOperand(source))
+    return failure();
+
+  if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
+    if (parser.parseEqual())
+      return failure();
+
+    SmallVector<int64_t> outerDimsPermVec;
+    if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
+          int64_t value;
+          if (parser.parseInteger(value))
+            return failure();
+          outerDimsPermVec.push_back(value);
+          return success();
+        }))
+      return failure();
+    outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
+  }
+
+  if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
+    return failure();
+
+  SmallVector<int64_t> innerDimsPosVec;
+  if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
+        int64_t value;
+        if (parser.parseInteger(value))
+          return failure();
+        innerDimsPosVec.push_back(value);
+        return success();
+      }))
+    return failure();
+  innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);
+
+  if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
+    return failure();
+
+  DenseI64ArrayAttr staticTilesAttr;
+  if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
+    return failure();
+  for (auto val : staticTilesAttr.asArrayRef())
+    staticTiles.push_back(val);
+
+  if (parser.parseKeyword("into") || parser.parseOperand(dest))
+    return failure();
+
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  if (parser.parseColon() || parser.parseType(sourceType))
+    return failure();
+
+  bool hasArrow = succeeded(parser.parseOptionalArrow());
+  if (hasArrow) {
+    if (parser.parseType(destType))
+      return failure();
+  }
+
+  bool isMemRef = llvm::isa<MemRefType>(sourceType);
+  if (!hasArrow) {
+    return parser.emitError(parser.getCurrentLocation(),
+                            "pack/unpack requires '->' and destination type");
+  }
+
+  if (!isMemRef)
+    resultType = destType;
+
+  if (parser.resolveOperand(source, sourceType, result.operands) ||
+      parser.resolveOperand(dest, destType, result.operands))
+    return failure();
+
+  if (!dynamicTiles.empty() &&
+      parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
+                             result.operands))
+    return failure();
+
+  result.addAttribute("static_inner_tiles",
+                      parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
+  result.addAttribute("inner_dims_pos", innerDimsPos);
+  if (outerDimsPerm)
+    result.addAttribute("outer_dims_perm", outerDimsPerm);
+
+  SmallVector<int32_t> segmentSizes = {
+      1, 1, 0, static_cast<int32_t>(dynamicTiles.size())};
+  result.addAttribute("operandSegmentSizes",
+                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
+
+  if (!isMemRef)
+    result.addTypes(resultType);
+
+  return success();
+}
+
+void UnPackOp::print(OpAsmPrinter &p) {
+  p << " " << getSource();
+
+  if (!getOuterDimsPerm().empty()) {
+    p << " outer_dims_perm = [";
+    llvm::interleaveComma(getOuterDimsPerm(), p);
+    p << "]";
+  }
+
+  p << " inner_dims_pos = [";
+  llvm::interleaveComma(getInnerDimsPos(), p);
+  p << "]";
+
+  p << " inner_tiles = ";
+  printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
+
+  p << " into " << getDest();
+
+  p.printOptionalAttrDict((*this)->getAttrs(),
+                          {"static_inner_tiles", "inner_dims_pos",
+                           "outer_dims_perm", "operandSegmentSizes"});
+
+  p << " : " << getSource().getType();
+  p << " -> " << getDest().getType();
 }
 
 LogicalResult
 UnPackOp::reifyResultShapes(OpBuilder &builder,
                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  if (!hasPureTensorSemantics())
+    return failure();
   return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
 }
 
@@ -5843,6 +6204,8 @@ LogicalResult UnPackOp::verify() {
 }
 
 Speculation::Speculatability UnPackOp::getSpeculatability() {
+  if (!hasPureTensorSemantics())
+    return Speculation::NotSpeculatable;
   // See PackOp::getSpeculatability.
   if (!areTilesAndTiledDimsAllConstant(*this))
     return Speculation::NotSpeculatable;
@@ -5949,6 +6312,10 @@ static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
 
 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                      PatternRewriter &rewriter) {
+  // TODO: Support Memref UnPackOp. Temporarily return failure.
+  if (!unPackOp.hasPureTensorSemantics())
+    return failure();
+
   /// unpack(pack(x)) -> x
   if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
     if (packOp.getSourceType() != unPackOp.getDestType())
@@ -6005,11 +6372,11 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
       dest = tensor::CastOp::create(rewriter, loc, newDestType,
                                     unPackOp.getDest());
     }
-    Value newOp = UnPackOp::create(
+    UnPackOp newOp = UnPackOp::create(
         rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
         unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
     rewriter.replaceOpWithNewOp<tensor::CastOp>(
-        unPackOp, unPackOp.getResult().getType(), newOp);
+        unPackOp, unPackOp.getResult().getType(), newOp.getResult());
     return success();
   }
 
@@ -6043,16 +6410,24 @@ bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
 }
 
 bool UnPackOp::isLikeUnPad() {
-  RankedTensorType packedTensorType = getSourceType();
+  ShapedType packedTensorType = getSourceType(); // tensor or memref
   return isLikePadUnPad(*this, packedTensorType);
 }
 
-OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
+::mlir::LogicalResult // Multi-result form; memref ops have no result.
+UnPackOp::fold(FoldAdaptor adaptor,
+               ::llvm::SmallVectorImpl<OpFoldResult> &results) {
+  // TODO: Support memref UnPackOp (no result to fold). Fail for now.
+  if (!hasPureTensorSemantics())
+    return failure();
+
   if (OpFoldResult reshapedSource = reshapeConstantSource(
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
-          getResult().getType()))
-    return reshapedSource;
-  return {};
+          cast<TensorType>(getResult().getType()))) { // checked above
+    results.push_back(reshapedSource);
+    return success();
+  }
+  return failure(); // Nothing to fold.
 }
 
 /// Folds a tensor.cast op into a consuming UnPackOp op if the
@@ -6074,6 +6449,10 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
 
   LogicalResult matchAndRewrite(UnPackOp op,
                                 PatternRewriter &rewriter) const override {
+    // TODO: Support Memref UnPackOp. Temporarily return failure.
+    if (!op.hasPureTensorSemantics())
+      return failure();
+
     if (!tensor::hasFoldableTensorCastOperand(op))
       return failure();
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 6912da3ffbc83..6ea1eb50b13ce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -90,6 +90,10 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                       linalg::PackOp packOp, AffineMap operandMap,
                       ArrayRef<unsigned> blocksStartDimPos,
                       bool transposeOuterBlocks, bool transposeInnerBlocks) {
+  // TODO: Support Memref PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics())
+    return failure();
+
   assert(operandMap.getNumDims() >= 4 &&
          "expected at least 4D prepacked matmul");
   assert(blocksStartDimPos.size() >= 2 &&

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 419f6a0d3c010..d36ca43a6cbb3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -282,8 +282,8 @@ static bool getPackedOperandDetails(
       });
   bool requirePadding = linalg::PackOp::requirePaddingValueStrict(
       inputType.getShape(), innerDimsPos,
-      linalg::PackOp::inferPackedType(inputType, maybeIntInnerTileSizes,
-                                      innerDimsPos, outerDimsPerm)
+      linalg::PackOp::inferPackedTensorType(inputType, maybeIntInnerTileSizes,
+                                            innerDimsPos, outerDimsPerm)
           .getShape(),
       outerDimsPerm, innerTileSizes);
   currOperandDetails.innerDimsPos = innerDimsPos;
@@ -341,10 +341,11 @@ static std::tuple<Value, AffineMap> getOrCreatePackedViewOfOperand(
       b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
   auto poison = ub::PoisonOp::create(
       b, loc, getElementTypeOrSelf(opOperand->get().getType()));
-  Value packedOperand =
+  PackOp packedOperand =
       linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
                              innerTileSizes, poison, outerDimsPerm);
-  return std::make_tuple(packedOperand, currOperandDetails.indexingMap);
+  return std::make_tuple(packedOperand.getResult(),
+                         currOperandDetails.indexingMap);
 }
 
 /// This function is a helper subroutine to pack a genericOp and return it. It
@@ -571,6 +572,9 @@ struct BubbleUpPackOpThroughGenericOpPattern
 
   LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                 PatternRewriter &rewriter) const override {
+    if (!packOp.hasPureTensorSemantics())
+      return failure();
+
     auto genericOp = bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn,
                                                     poisonPaddingOk);
     if (failed(genericOp))
@@ -594,6 +598,9 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
 
   LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                 PatternRewriter &rewriter) const override {
+    if (!packOp.hasPureTensorSemantics())
+      return failure();
+
     auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
     if (!padOp)
       return failure();
@@ -653,19 +660,19 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
     lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
     highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
 
-    auto newPadOp =
-        tensor::PadOp::create(rewriter, loc, /*result=*/Type(), sourcePack,
-                              lowPad, highPad, paddingVal, padOp.getNofold());
+    auto newPadOp = tensor::PadOp::create(
+        rewriter, loc, /*result=*/Type(), sourcePack.getResult(), lowPad,
+        highPad, paddingVal, padOp.getNofold());
 
     // If the pad has more than one user, create an unpack on the new pad to
     // replace the other uses.
     if (!padOp->hasOneUse()) {
       auto unpackEmpty = linalg::UnPackOp::createDestinationTensor(
           rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
-      Value unpackedPad =
+      UnPackOp unpackedPad =
           linalg::UnPackOp::create(rewriter, loc, newPadOp, unpackEmpty,
                                    innerDimsPos, mixedTiles, outerDimsPerm);
-      rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
+      rewriter.replaceAllUsesExcept(padOp, unpackedPad.getResult(), sourcePack);
     }
 
     // Replace the pack with the new pad.
@@ -763,6 +770,9 @@ static LogicalResult
 bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                    linalg::PackOp packOp,
                                    PatternRewriter &rewriter) {
+  if (!packOp.hasPureTensorSemantics())
+    return failure();
+
   SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
   ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
   ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
@@ -812,8 +822,8 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
   }
 
   auto newCollapseOp = tensor::CollapseShapeOp::create(
-      rewriter, collapseOp.getLoc(), packOp.getType(), newPackOp,
-      newReassocIndices);
+      rewriter, collapseOp.getLoc(), packOp.getResult().getType(),
+      newPackOp.getResult(), newReassocIndices);
   rewriter.replaceOp(packOp, newCollapseOp);
 
   return success();
@@ -868,6 +878,9 @@ static LogicalResult
 bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
                                  linalg::PackOp packOp,
                                  PatternRewriter &rewriter) {
+  if (!packOp.hasPureTensorSemantics())
+    return failure();
+
   // Outer dimensions permutation is not supported currently.
   // TODO: Handle outer_dims_perm variants.
   ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
@@ -918,7 +931,7 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
   // If reassociation is not possible, then reordering cannot happen.
   // This can be caused by pack padding affecting previously expanded
   // dimensions or packing extending dimensions.
-  RankedTensorType newPackType = linalg::PackOp::inferPackedType(
+  RankedTensorType newPackType = linalg::PackOp::inferPackedTensorType(
       expandOp.getSrcType(), packOp.getStaticInnerTiles(),
       projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
   auto reassocExpand =
@@ -930,14 +943,14 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
   Value destTensor = linalg::PackOp::createDestinationTensor(
       rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
       projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
-  Value packedVal = linalg::PackOp::create(
+  PackOp packedVal = linalg::PackOp::create(
       rewriter, packOp.getLoc(), expandOp.getSrc(), destTensor,
       projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
       /*outerDimsPerm=*/SmallVector<int64_t>{});
 
-  Value newExpandOp = tensor::ExpandShapeOp::create(rewriter, packOp.getLoc(),
-                                                    packOp.getDestType(),
-                                                    packedVal, *reassocExpand);
+  Value newExpandOp = tensor::ExpandShapeOp::create(
+      rewriter, packOp.getLoc(), packOp.getDestType(), packedVal.getResult(),
+      *reassocExpand);
   rewriter.replaceOp(packOp, newExpandOp);
 
   return success();
@@ -951,6 +964,9 @@ class BubbleUpPackOpThroughReshapeOp final
 
   LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                 PatternRewriter &rewriter) const override {
+    if (!packOp.hasPureTensorSemantics())
+      return failure();
+
     Operation *srcOp = packOp.getSource().getDefiningOp();
     // Currently only support when the pack op is the only user.
     if (!srcOp || !(srcOp->getNumResults() == 1) ||
@@ -1001,6 +1017,9 @@ class BubbleUpPackOpThroughReshapeOp final
 static LogicalResult pushDownUnPackOpThroughExpandShape(
     linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
     PatternRewriter &rewriter, ControlPropagationFn controlFn) {
+  if (!unPackOp.hasPureTensorSemantics())
+    return failure();
+
   // User controlled propagation function.
   if (!controlFn(&expandOp.getSrcMutable()))
     return failure();
@@ -1048,7 +1067,7 @@ static LogicalResult pushDownUnPackOpThroughExpandShape(
     nextPos += 1;
   }
 
-  RankedTensorType newExpandType = linalg::PackOp::inferPackedType(
+  RankedTensorType newExpandType = linalg::PackOp::inferPackedTensorType(
       expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
   auto newExpandOp =
       tensor::ExpandShapeOp::create(rewriter, expandOp.getLoc(), newExpandType,
@@ -1075,6 +1094,9 @@ class PushDownUnPackOpThroughReshapeOp final
 
   LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
+    if (!unPackOp.hasPureTensorSemantics())
+      return failure();
+
     Value result = unPackOp.getResult();
     // Currently only support unpack op with the single user.
     if (!result.hasOneUse()) {
@@ -1274,6 +1296,9 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
     if (!unpackOp)
       return failure();
 
+    if (!unpackOp.hasPureTensorSemantics())
+      return failure();
+
     if (!controlFn(&padOp.getSourceMutable()))
       return failure();
 
@@ -1313,7 +1338,7 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
         tensor::EmptyOp::create(rewriter, loc, padOp.getResultType().getShape(),
                                 padOp.getResultType().getElementType());
 
-    Value replacement = linalg::UnPackOp::create(
+    UnPackOp replacement = linalg::UnPackOp::create(
         rewriter, loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
         unpackOp.getMixedTiles(), outerDimsPerm);
     rewriter.replaceOp(padOp, replacement);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 1d4c11e418006..993eae62535c3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
@@ -110,8 +111,11 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
                                 PatternRewriter &rewriter) const override {
     if (packOp.getPaddingValue())
       return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+    // TODO: Support Memref PackOp. Temporarily return failure.
+    if (!packOp.hasPureTensorSemantics())
+      return failure();
 
-    RankedTensorType sourceType = packOp.getSourceType();
+    ShapedType sourceType = packOp.getSourceType();
     if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
         failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
                           packOp.getStaticTiles())) &&
@@ -119,7 +123,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
       return failure();
     }
 
-    RankedTensorType destType = packOp.getDestType();
+    ShapedType destType = packOp.getDestType();
     auto reassociation =
         getReassociationIndicesForReshape(sourceType, destType);
     if (!reassociation)
@@ -157,8 +161,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
           "expects outer_dims_perm is empty or an identity permutation");
     }
 
-    RankedTensorType sourceType = unpackOp.getSourceType();
-    RankedTensorType destType = unpackOp.getDestType();
+    ShapedType sourceType = unpackOp.getSourceType();
+    ShapedType destType = unpackOp.getDestType();
     if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
       return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
 
@@ -173,7 +177,11 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
 
   LogicalResult matchAndRewrite(UnPackOp unpackOp,
                                 PatternRewriter &rewriter) const override {
-    RankedTensorType destType = unpackOp.getDestType();
+    // TODO: Support Memref UnPackOp. Temporarily return failure.
+    if (!unpackOp.hasPureTensorSemantics())
+      return failure();
+
+    ShapedType destType = unpackOp.getDestType();
     if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
         failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
                           unpackOp.getStaticTiles())) &&
@@ -181,7 +189,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
       return failure();
     }
 
-    RankedTensorType sourceType = unpackOp.getSourceType();
+    ShapedType sourceType = unpackOp.getSourceType();
     auto reassociation =
         getReassociationIndicesForReshape(sourceType, destType);
     if (!reassociation)
@@ -225,7 +233,7 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
     // sizes - that is because it would be impossible to compute the padding
     // size and hence to establish whether "artificial" padding would be
     // created.
-    RankedTensorType unpackedType = packOp.getSourceType();
+    ShapedType unpackedType = packOp.getSourceType();
     SmallVector<int64_t> outerShapeWithoutTranspose =
         getPackedOuterShapeWithoutTransposition(packOp);
     for (auto [pos, tileSize, high] :
@@ -274,6 +282,10 @@ struct FoldUnpackWithExtractSliceOp
     if (!unpackOp)
       return failure();
 
+    // TODO: Support Memref UnPackOp. Temporarily return failure.
+    if (!unpackOp.hasPureTensorSemantics())
+      return failure();
+
     // User controlled folding function.
     if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
       return failure();
@@ -336,6 +348,10 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     if (!packOp)
       return failure();
 
+    // TODO: Support Memref PackOp. Temporarily return failure.
+    if (!packOp.hasPureTensorSemantics())
+      return failure();
+
     // User controlled folding function.
     if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
       return failure();
@@ -395,6 +411,10 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 
   LogicalResult matchAndRewrite(PackOp packOp,
                                 PatternRewriter &rewriter) const override {
+    // TODO: Support Memref PackOp. Temporarily return failure.
+    if (!packOp.hasPureTensorSemantics())
+      return failure();
+
     auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
     if (!linalgOp)
       return failure();
@@ -456,6 +476,10 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
     if (!unPackOp)
       return failure();
 
+    // TODO: Support Memref UnPackOp. Temporarily return failure.
+    if (!unPackOp.hasPureTensorSemantics())
+      return failure();
+
     // User controlled folding function.
     if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
       return failure();
@@ -504,6 +528,10 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
 
   LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
+    // TODO: Support Memref UnPackOp. Temporarily return failure.
+    if (!unPackOp.hasPureTensorSemantics())
+      return failure();
+
     auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
     if (!linalgOp)
       return failure();
@@ -568,6 +596,10 @@ struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
 
   LogicalResult matchAndRewrite(PackOp packOp,
                                 PatternRewriter &rewriter) const override {
+    // TODO: Support Memref PackOp. Temporarily return failure.
+    if (!packOp.hasPureTensorSemantics())
+      return failure();
+
     // Check for tensor.empty source.
     auto emptyOp = packOp.getSource().getDefiningOp<tensor::EmptyOp>();
     if (!emptyOp)
@@ -592,6 +624,10 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
 
   LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
+    // TODO: Support Memref UnPackOp. Temporarily return failure.
+    if (!unPackOp.hasPureTensorSemantics())
+      return failure();
+
     // Check for tensor.empty source.
     auto emptyOp = unPackOp.getSource().getDefiningOp<tensor::EmptyOp>();
     if (!emptyOp)

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 50a84ace09258..5d39c4731dd1b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -766,6 +766,10 @@ struct PackOpTiling
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
     auto packOp = cast<PackOp>(op);
+    // TODO: Support Memref PackOp. Temporarily return failure.
+    if (!packOp.hasPureTensorSemantics())
+      return failure();
+
     Location loc = packOp.getLoc();
 
     // The tiling is applied on interchanged dimensions. We have to undo the
@@ -1010,6 +1014,10 @@ struct PackOpTiling
     ArrayRef<OpFoldResult> sizes(allSizes[0]);
 
     auto packOp = cast<PackOp>(op);
+    // TODO: Support Memref PackOp. Temporarily return failure.
+    if (!packOp.hasPureTensorSemantics())
+      return failure();
+
     Location loc = packOp.getLoc();
 
     int64_t inputRank = packOp.getSourceRank();
@@ -1189,6 +1197,10 @@ struct UnPackOpTiling
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
     auto unpackOp = cast<UnPackOp>(op);
+    // TODO: Support Memref UnPackOp. Temporarily return failure.
+    if (!unpackOp.hasPureTensorSemantics())
+      return failure();
+
     int64_t srcRank = unpackOp.getSourceRank();
     int64_t destRank = unpackOp.getDestRank();
     int64_t numInnerTiles = srcRank - destRank;
@@ -1359,6 +1371,10 @@ struct UnPackOpTiling
       return failure();
     }
     auto unPackOp = cast<UnPackOp>(op);
+    // TODO: Support Memref UnPackOp. Temporarily return failure.
+    if (!unPackOp.hasPureTensorSemantics())
+      return failure();
+
     ArrayRef<OpFoldResult> offsets(allOffsets[0]);
     ArrayRef<OpFoldResult> sizes(allSizes[0]);
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index fc7cdad0ee33d..48ebd1644bbef 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -26,6 +26,8 @@
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
@@ -218,6 +220,10 @@ struct PackedOperandsDimList {
 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                              linalg::PackOp packOp,
                                              bool lowerPadLikeWithInsertSlice) {
+  // TODO: Support Memref PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics())
+    return failure();
+
   // 1. Filter out NYI cases.
   auto packedTensorType =
       cast<RankedTensorType>(packOp->getResultTypes().front());
@@ -345,11 +351,15 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
 FailureOr<LowerUnPackOpResult>
 linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
                     bool lowerUnpadLikeWithExtractSlice) {
+  // TODO: Support Memref UnPackOp. Temporarily return failure.
+  if (!unPackOp.hasPureTensorSemantics())
+    return failure();
+
   Location loc = unPackOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
 
-  RankedTensorType packedTensorType = unPackOp.getSourceType();
+  auto packedTensorType = cast<RankedTensorType>(unPackOp.getSourceType());
   int64_t packedRank = packedTensorType.getRank();
 
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -549,7 +559,7 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
         packOps.push_back(linalg::PackOp::create(
             rewriter, loc, operand, dest, innerPos, innerPackSizes, zero));
       }
-      inputsAndInits.push_back(packOps.back());
+      inputsAndInits.push_back(packOps.back().getResult());
     }
   }
 
@@ -576,7 +586,7 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
     unPackOps.push_back(linalg::UnPackOp::create(
         rewriter, packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
         maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
-    results.push_back(unPackOps.back());
+    results.push_back(unPackOps.back().getResult());
   }
 
   // Step 5. Replace `linalgOp`.
@@ -666,7 +676,7 @@ linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
   linalg::PackOp transposedPackOp =
       packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
 
-  if (!packOp.getResult().hasOneUse())
+  if (packOp.hasPureBufferSemantics() || !packOp.getResult().hasOneUse())
     return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");
 
   OpOperand &packUse = *packOp->getUses().begin();
@@ -727,7 +737,10 @@ linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
   }
 
   // Step 4. Finally, replace packOp now that we don't need it anymore.
-  rewriter.replaceOp(packOp, transposedPackOp->getResults());
+  if (packOp.hasPureTensorSemantics())
+    rewriter.replaceOp(packOp, transposedPackOp->getResults());
+  else
+    rewriter.eraseOp(packOp);
 
   return PackTransposeResult{transposedPackOp, transposedLinalgOp,
                              transposedUnPackOp};
@@ -1019,6 +1032,10 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
 static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                            linalg::PackOp packOp) {
   Value input = packOp.getSource();
+  // TODO: Support Memref PackOp. Temporarily return the unpadded source.
+  if (!packOp.hasPureTensorSemantics())
+    return input;
+
   if (!packOp.getPaddingValue()) {
     return input;
   }
@@ -1135,6 +1152,10 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
 
 LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
     linalg::PackOp packOp, PatternRewriter &rewriter) const {
+  // TODO: Support Memref PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics())
+    return failure();
+
   if (llvm::any_of(packOp.getTiledOuterDims(),
                    [](int64_t dim) { return dim != 1; })) {
     return rewriter.notifyMatchFailure(
@@ -1156,8 +1177,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
 
         // Check whether this dim has been permuted. Permuting unit dims is fine
         // as that's effectively a no-op.
-        if (dim < prev && (packOp.getType().getShape()[prev] != 1 ||
-                           packOp.getType().getShape()[dim] != 1))
+        if (dim < prev && (packOp.getResult().getType().getShape()[prev] != 1 ||
+                           packOp.getResult().getType().getShape()[dim] != 1))
           return false;
 
         prev = dim;
@@ -1274,6 +1295,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
 
 LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
     linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
+  if (!unpackOp.hasPureTensorSemantics())
+    return failure();
+
   int64_t destRank = unpackOp.getDestRank();
   ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
   ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c7d5dff74c5a9..6f73f4f57e50d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1877,8 +1877,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   SmallVector<int64_t> preTransposeWriteVecSizses(writeVectorSizes);
   auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata);
   applyPermutationToVector(preTransposeWriteVecSizses, destInvPermutation);
-  auto preTransposeWriteVecType = VectorType::get(
-      preTransposeWriteVecSizses, packOp.getType().getElementType());
+  auto preTransposeWriteVecType =
+      VectorType::get(preTransposeWriteVecSizses,
+                      packOp.getResult().getType().getElementType());
 
   // Compute vector type for the _read_ opeartion. This is simply
   // pre-transpose-write-vector-type with the dimensions collapsed
@@ -1954,7 +1955,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);
 
-  RankedTensorType unpackTensorType = unpackOp.getSourceType();
+  ShapedType unpackTensorType = unpackOp.getSourceType();
 
   ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
   bool useInBoundsInsteadOfMasking = false;
@@ -2117,6 +2118,10 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
 static LogicalResult
 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                               ArrayRef<int64_t> inputVectorSizes) {
+  // TODO: Support Memref UnPackOp. Temporarily return failure.
+  if (!unpackOp.hasPureTensorSemantics())
+    return failure();
+
   // If there are no input vector sizes and all shapes are static, there is
   // nothing left to check.
   if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
@@ -2454,6 +2459,10 @@ static LogicalResult vectorizeLinalgOpPrecondition(
 static LogicalResult
 vectorizePackOpPrecondition(linalg::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
+  // TODO: Support Memref PackOp. Temporarily return failure.
+  if (!packOp.hasPureTensorSemantics())
+    return failure();
+
   auto padValue = packOp.getPaddingValue();
   Attribute cstAttr;
   // TODO: Relax this condiiton

diff  --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 0a97bc03f584b..1f6d1a68fbbb8 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -308,9 +308,12 @@ def pack(
         _inner_tiles,
         static_inner_tiles,
     ) = _dispatch_mixed_values(inner_tiles)
+    dest = _get_op_result_or_value(dest)
+    result_type = dest.type if isinstance(dest.type, RankedTensorType) else None
 
     return _get_op_result_or_op_results(
         PackOp(
+            result=result_type,
             source=source,
             dest=dest,
             inner_dims_pos=inner_dims_pos,
@@ -340,9 +343,11 @@ def unpack(
         _inner_tiles,
         static_inner_tiles,
     ) = _dispatch_mixed_values(inner_tiles)
-
+    dest = _get_op_result_or_value(dest)
+    result_type = dest.type if isinstance(dest.type, RankedTensorType) else None
     return _get_op_result_or_op_results(
         UnPackOp(
+            result=result_type,
             source=source,
             dest=dest,
             inner_dims_pos=inner_dims_pos,

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index f4020ede4854e..08eb40d4cb442 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -2058,3 +2058,59 @@ func.func @no_fold_extract_slice_into_unpack_non_zero_offset(
 //  CHECK-SAME:       into %[[DEST]]
 //       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
 //       CHECK:   return %[[SLICE]]
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_cast_unpack_dynamic_tile_size(
+// CHECK-SAME:      %[[SRC:.*]]: tensor<1x1x8x1xi32>,
+// CHECK-SAME:      %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
+// CHECK:           %[[RES:.*]] = linalg.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {test_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+func.func @fold_cast_unpack_dynamic_tile_size(
+  %src: tensor<1x1x8x1xi32>,
+  %res: tensor<7x?xi32>) -> tensor<7x?xi32> {
+
+    %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+    %c8 = arith.constant 8 : index
+    %unpack = linalg.unpack %cast
+      inner_dims_pos = [0, 1]
+      inner_tiles = [%c8, 1]
+      into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+    return %unpack : tensor<7x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_pack_unpack_tensor
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x3xf32>) -> tensor<2x3xf32>
+// CHECK:       return %[[ARG0]] : tensor<2x3xf32>
+func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  %unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
+             into %x : tensor<2x3xf32> -> tensor<2x3xf32>
+  %packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
+             into %x : tensor<2x3xf32> -> tensor<2x3xf32>
+  return %packed : tensor<2x3xf32>
+}
+
+// -----
+
+// Test that pack/unpack canonicalization is disabled for memref versions.
+// CHECK-LABEL: func.func @pack_unpack_memref_no_canonicalization
+// CHECK: linalg.pack
+// CHECK: linalg.unpack
+func.func @pack_unpack_memref_no_canonicalization(%source: memref<128x256xf32>, %packed: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) {
+  linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %packed : memref<128x256xf32> -> memref<16x8x8x32xf32>
+  linalg.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
+  return
+}
+
+// -----
+
+// Test that unpack/pack canonicalization is disabled for memref versions.
+// CHECK-LABEL: func.func @unpack_pack_memref_no_canonicalization
+// CHECK: linalg.unpack
+// CHECK: linalg.pack
+func.func @unpack_pack_memref_no_canonicalization(%packed: memref<16x8x8x32xf32>, %unpacked: memref<128x256xf32>, %dest: memref<16x8x8x32xf32>) {
+  linalg.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %unpacked : memref<16x8x8x32xf32> -> memref<128x256xf32>
+  linalg.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %dest : memref<128x256xf32> -> memref<16x8x8x32xf32>
+  return
+}

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 74928920c695a..bfb92c3289a49 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -755,3 +755,77 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
 // CHECK-LABEL: func @conv2d_channel_first_q_promote(
 // CHECK:   %[[arg0:[a-zA-z0-9]*]]: tensor<100x3x224x224xi8>, %[[arg1:[a-zA-z0-9]*]]: tensor<64x3x5x5xi8>, %[[arg2:[a-zA-z0-9]*]]: i8, %[[arg3:[a-zA-z0-9]*]]: i8)
 // CHECK:         linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
+
+// -----
+
+func.func @pack_memref(%source: memref<128x256xf32>, %dest: memref<8x16x8x32xf32>) {
+  linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+      into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32>
+  return
+}
+
+// CHECK-LABEL: func @pack_memref(
+// CHECK:   %[[source:[a-zA-z0-9]*]]: memref<128x256xf32>, %[[dest:[a-zA-z0-9]*]]: memref<8x16x8x32xf32>) {
+// CHECK:     linalg.pack %[[source]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[dest]] : memref<128x256xf32> -> memref<8x16x8x32xf32>
+// CHECK:   return
+// CHECK: }
+// -----
+
+func.func @unpack_memref(%source: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) {
+  linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+      into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
+  return
+}
+
+// CHECK-LABEL: func @unpack_memref(
+// CHECK:   %[[source:[a-zA-z0-9]*]]: memref<16x8x8x32xf32>, %[[dest:[a-zA-z0-9]*]]: memref<128x256xf32>) {
+// CHECK:         linalg.unpack %[[source]] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[dest]] : memref<16x8x8x32xf32> -> memref<128x256xf32>
+// CHECK:   return
+
+// -----
+
+// CHECK-LABEL: func @test_pack_memref
+func.func @test_pack_memref(%arg0: memref<128x256xf32>, %arg1: memref<16x8x8x32xf32>) {
+  // CHECK-NOT: %{{.*}} = linalg.pack
+  // CHECK: linalg.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : memref<128x256xf32> -> memref<16x8x8x32xf32>
+  linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : memref<128x256xf32> -> memref<16x8x8x32xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @test_unpack_memref
+func.func @test_unpack_memref(%arg0: memref<16x8x8x32xf32>, %arg1: memref<128x256xf32>) {
+  // CHECK: linalg.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : memref<16x8x8x32xf32>
+  linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : memref<16x8x8x32xf32> -> memref<128x256xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @test_pack_memref_with_padding
+func.func @test_pack_memref_with_padding(%arg0: memref<127x255xf32>, %arg1: memref<16x8x8x32xf32>, %pad: f32) {
+  // CHECK: linalg.pack %{{.*}} padding_value(%{{.*}} : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : memref<127x255xf32>
+  linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : memref<127x255xf32> -> memref<16x8x8x32xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @test_pack_tensor
+func.func @test_pack_tensor(%arg0: tensor<128x256xf32>, %arg1: tensor<16x8x8x32xf32>) -> tensor<16x8x8x32xf32> {
+  // CHECK: %[[RESULT:.*]] = linalg.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
+  %0 = linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
+  // CHECK: return %[[RESULT]] : tensor<16x8x8x32xf32>
+  return %0 : tensor<16x8x8x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @test_unpack_tensor
+func.func @test_unpack_tensor(%arg0: tensor<16x8x8x32xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> {
+  // CHECK: %[[RESULT:.*]] = linalg.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
+  %0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
+  // CHECK: return %[[RESULT]] : tensor<128x256xf32>
+  return %0 : tensor<128x256xf32>
+}

diff  --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 92591cd59fb40..68b32098b7782 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -664,6 +664,20 @@ def tensor_pack(src, dst):
 
                 return unpacked
 
+            @func.FuncOp.from_py_func(
+                MemRefType.get((128, 128), f32),
+                MemRefType.get((16, 16, 8, 8), f32),
+            )
+            def memref_pack(src, dst):
+                linalg.pack(src, dst, inner_dims_pos=[1, 0], inner_tiles=[8, 8])
+
+                linalg.unpack(
+                    dst,
+                    src,
+                    inner_dims_pos=[0, 1],
+                    inner_tiles=[8, 8],
+                )
+
         # CHECK-LABEL:   func.func @tensor_pack(
         # CHECK-SAME:      %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<16x16x8x8xf32>) -> tensor<128x128xf32> {
         # CHECK:           %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
@@ -671,6 +685,12 @@ def tensor_pack(src, dst):
         # CHECK:           %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
         # CHECK:           return %[[VAL_4]] : tensor<128x128xf32>
         # CHECK:         }
+        # CHECK-LABEL:   func.func @memref_pack(
+        # CHECK-SAME:      %[[VAL_0:.*]]: memref<128x128xf32>, %[[VAL_1:.*]]: memref<16x16x8x8xf32>) {
+        # CHECK:           linalg.pack %[[VAL_0]] inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %[[VAL_1]] : memref<128x128xf32> -> memref<16x16x8x8xf32>
+        # CHECK:           linalg.unpack %[[VAL_1]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : memref<16x16x8x8xf32> -> memref<128x128xf32>
+        # CHECK:           return
+        # CHECK:         }
         print(module)
 
 


        


More information about the Mlir-commits mailing list