[Mlir-commits] [mlir] [draft][mlir][linalg] pack, unpack to take memref inputs (PR #129036)

Hyunsung Lee llvmlistbot at llvm.org
Thu Feb 27 23:19:33 PST 2025


https://github.com/ita9naiwa updated https://github.com/llvm/llvm-project/pull/129036

From 4d523adc3cf5eb581c43395e66aaa0012dbc179b Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Thu, 27 Feb 2025 19:54:30 +0900
Subject: [PATCH 1/7] draft

---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    | 72 +++++++++++++++++--
 .../Dialect/Linalg/IR/RelayoutOpInterface.td  |  1 +
 .../mlir/Dialect/Utils/ReshapeOpsUtils.h      |  4 +-
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp    |  2 +-
 4 files changed, 69 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 1e48a5e3a20ee..f8a4657c564ce 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -77,7 +77,20 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
     /// with `inner_dims_pos` rather than the packed tensor.
     SmallVector<int64_t> getTiledOuterDims();
   }];
-
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    ShapedType getInputType() {
+      return cast<ShapedType>(getInput().getType());
+    }
+    ShapedType getOutputType() {
+      return cast<ShapedType>(getOutput().getType());
+    }
+    int64_t getInputRank() {
+      return getInputType().getRank();
+    }
+    int64_t getOutputRank() {
+      return getOutputType().getRank();
+    }
+  }];
   let hasVerifier = 1;
 }
 
@@ -152,14 +165,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     // Note: Only tiled dimensions can be padded.
     ```
   }];
-  let arguments = (ins AnyRankedTensor:$source,
-                       AnyRankedTensor:$dest,
+  let arguments = (ins AnyShaped:$source,
+                       AnyShaped:$dest,
                        Optional<AnyType>:$padding_value,
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
                        DenseI64ArrayAttr:$inner_dims_pos,
                        Variadic<Index>:$inner_tiles,
                        DenseI64ArrayAttr:$static_inner_tiles);
-  let results = (outs AnyRankedTensor:$result);
+  let results = (outs AnyShaped:$result);
   let assemblyFormat = [{
     $source
     (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
@@ -179,6 +192,28 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
+    Value getOutput() {
+      return getDpsInitOperand(0)->get();
+    }
+
+    // Return the input operand.
+    Value getInput() {
+      return getDpsInputOperand(0)->get();
+    }
+    ShapedType getInputType() {
+      return cast<ShapedType>(getInput().getType());
+    }
+    ShapedType getOutputType() {
+      return cast<ShapedType>(getDest().getType());
+    }
+    int64_t getInputRank() {
+      return getInputType().getRank();
+    }
+    int64_t getOutputRank() {
+      return getOutputType().getRank();
+    }
+
+    LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
     // Method to get the shape of the result as `SmallVector<OpFoldResult>`.
     // This is a static method to allow getting the shape of the destination
     // expected while creating a `pack` op.
@@ -229,6 +264,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     /// 2. pads the other ones, and
     /// 3. doesn't shuffle the dimensions
     bool isLikePad();
+
   }];
 
   let hasCanonicalizeMethod = 1;
@@ -279,13 +315,13 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
         : tensor<8x16x8x32xf32> -> tensor<128x256xf32>
     ```
   }];
-  let arguments = (ins AnyRankedTensor:$source,
-                       AnyRankedTensor:$dest,
+  let arguments = (ins AnyShaped:$source,
+                       AnyShaped:$dest,
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
                        DenseI64ArrayAttr:$inner_dims_pos,
                        Variadic<Index>:$inner_tiles,
                        DenseI64ArrayAttr:$static_inner_tiles);
-  let results = (outs AnyRankedTensor:$result);
+  let results = (outs AnyShaped:$result);
   let assemblyFormat = [{
     $source
     (`outer_dims_perm` `=` $outer_dims_perm^)?
@@ -303,6 +339,28 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
+    Value getOutput() {
+      return getDpsInitOperand(0)->get();
+    }
+
+    // Return the input operand.
+    Value getInput() {
+      return getDpsInputOperand(0)->get();
+    }
+    ShapedType getInputType() {
+      return cast<ShapedType>(getInput().getType());
+    }
+    ShapedType getOutputType() {
+      return cast<ShapedType>(getDest().getType()); // uses getDest()
+    }
+    int64_t getInputRank() {
+      return getInputType().getRank();
+    }
+    int64_t getOutputRank() {
+      return getOutputType().getRank();
+    }
+    LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
+
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
         ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
index 2dec2fc4396f4..467d862d277eb 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
@@ -10,6 +10,7 @@
 #define LINALG_IR_RELAYOUTOPINTERFACE
 
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/IR/OpBase.td"
 
 def LinalgRelayoutOpInterface : OpInterface<"RelayoutOpInterface"> {
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 3af89a6ab3799..a86bf74a7b6a1 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -451,7 +451,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
 ///    %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
 ///          tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
 ///
-///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] : 
+///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
 ///          tensor<1x1x1x10xf32> into tensor<1x10xf32>
 ///    %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
 ///          tensor<1x10xf32> into tensor<10x10xf32>
@@ -573,7 +573,7 @@ PackingMetadata computePackingMetadata(int64_t packedRank,
 /// Removes the op and replaces the constant with a new constant of the result
 /// shape. When an optional cst attribute is passed, it is reshaped only if the
 /// splat value matches the value in the attribute.
-OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result,
+OpFoldResult reshapeConstantSource(DenseElementsAttr source, ShapedType result,
                                    std::optional<Attribute> cst = std::nullopt);
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 0336423c57b1d..4267732571801 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -485,7 +485,7 @@ PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
 }
 
 OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
-                                         TensorType result,
+  ShapedType result,
                                          std::optional<Attribute> cst) {
   if (source && source.isSplat() && result.hasStaticShape() &&
       (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
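
For illustration, this is the shape of IR the `AnyShaped` relaxation above is meant to admit. The tensor form follows the existing op documentation; the memref form is hypothetical (as written, the draft keeps the declared result even for buffers, and the final result-handling story for the memref variant is still open):

```mlir
// Existing tensor form: pack a 128x256 tensor into 8x32 inner tiles.
%packed = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
    into %dst : tensor<128x256xf32> -> tensor<16x8x8x32xf32>

// Hypothetical memref form admitted once $source/$dest are AnyShaped.
%p = linalg.pack %src_buf inner_dims_pos = [0, 1] inner_tiles = [8, 32]
    into %dst_buf : memref<128x256xf32> -> memref<16x8x8x32xf32>
```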

From 4f2dbf4848092942a7932387e39d3c1220d78923 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 08:00:32 +0900
Subject: [PATCH 2/7] draft

---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    | 44 -------------------
 1 file changed, 44 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index f8a4657c564ce..6e2c6171132f5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -192,28 +192,6 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
-    Value getOutput() {
-      return getDpsInitOperand(0)->get();
-    }
-
-    // Return the input operand.
-    Value getInput() {
-      return getDpsInputOperand(0)->get();
-    }
-    ShapedType getInputType() {
-      return cast<ShapedType>(getInput().getType());
-    }
-    ShapedType getOutputType() {
-      return cast<ShapedType>(getDest().getType());
-    }
-    int64_t getInputRank() {
-      return getInputType().getRank();
-    }
-    int64_t getOutputRank() {
-      return getOutputType().getRank();
-    }
-
-    LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
     // Method to get the shape of the result as `SmallVector<OpFoldResult>`.
     // This is a static method to allow getting the shape of the destination
     // expected while creating a `pack` op.
@@ -339,28 +317,6 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
-    Value getOutput() {
-      return getDpsInitOperand(0)->get();
-    }
-
-    // Return the input operand.
-    Value getInput() {
-      return getDpsInputOperand(0)->get();
-    }
-    ShapedType getInputType() {
-      return cast<ShapedType>(getInput().getType());
-    }
-    ShapedType getOutputType() {
-      return cast<ShapedType>(getDest().getType()); // uses getDest()
-    }
-    int64_t getInputRank() {
-      return getInputType().getRank();
-    }
-    int64_t getOutputRank() {
-      return getOutputType().getRank();
-    }
-    LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
-
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
         ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);

From 226230c9445084671531d755d5c3f5612bed7d67 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 08:01:05 +0900
Subject: [PATCH 3/7] draft

---
 .../mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td   | 15 +--------------
 1 file changed, 1 insertion(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 6e2c6171132f5..c68c395fc6337 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -77,20 +77,7 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
     /// with `inner_dims_pos` rather than the packed tensor.
     SmallVector<int64_t> getTiledOuterDims();
   }];
-  let extraClassDeclaration = commonExtraClassDeclaration # [{
-    ShapedType getInputType() {
-      return cast<ShapedType>(getInput().getType());
-    }
-    ShapedType getOutputType() {
-      return cast<ShapedType>(getOutput().getType());
-    }
-    int64_t getInputRank() {
-      return getInputType().getRank();
-    }
-    int64_t getOutputRank() {
-      return getOutputType().getRank();
-    }
-  }];
+
   let hasVerifier = 1;
 }
 

From 0c184dfc85cdb0d89d62aa8cafc4f752e1acc654 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 09:44:08 +0900
Subject: [PATCH 4/7] init

---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    | 10 +++---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 14 ++++-----
 .../Transforms/PackAndUnpackPatterns.cpp      | 12 +++----
 .../Dialect/Linalg/Transforms/Transforms.cpp  | 31 +++++++++++++++----
 .../Linalg/Transforms/Vectorization.cpp       |  2 +-
 mlir/lib/Tools/mlir-opt/launch.json           | 13 ++++++++
 6 files changed, 57 insertions(+), 25 deletions(-)
 create mode 100644 mlir/lib/Tools/mlir-opt/launch.json

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index c68c395fc6337..785c7cc924159 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -43,10 +43,10 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
   code commonExtraClassDeclaration = [{
     size_t getSourceRank() { return getSourceType().getRank(); };
     size_t getDestRank() { return getDestType().getRank(); };
-    RankedTensorType getSourceType() {
-      return ::llvm::cast<RankedTensorType>(getSource().getType()); };
-    RankedTensorType getDestType() {
-      return ::llvm::cast<RankedTensorType>(getDest().getType()); };
+    ShapedType getSourceType() {
+      return ::llvm::cast<ShapedType>(getSource().getType()); };
+    ShapedType getDestType() {
+      return ::llvm::cast<ShapedType>(getDest().getType()); };
 
     MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
 
@@ -190,7 +190,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     // Method to get the `RankedTensorType` of the result based on the inner
     // tiles, position of the inner tiles (innerDimsPos)  and interchange vector
     // of outer loops (outerDimsPerm).
-    static RankedTensorType inferPackedType(RankedTensorType sourceType,
+    static RankedTensorType inferPackedType(ShapedType sourceType,
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
         ArrayRef<int64_t> outerDimsPerm = {});
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..f4f08d9d4acf7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -803,7 +803,7 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
           rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
     }
 
-    RankedTensorType srcPadType = srcPadOp.getSourceType();
+    ShapedType srcPadType = srcPadOp.getSourceType();
     SmallVector<OpFoldResult, 4> newSizes;
     for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
       if (srcPadType.isDynamicDim(i)) {
@@ -4433,7 +4433,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
     return op->emitError("invalid zero tile factor");
 
   // Verify inner_dims_pos and outer_dims_perm.
-  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
+  ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                       ? packOrUnPack.getSourceType()
                                       : packOrUnPack.getDestType();
   size_t unpackedRank = unpackedType.getRank();
@@ -4747,7 +4747,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
 
 /// Get the expected packed type based on source type, tile factors, position of
 /// the inner tiles and permutation of the outer tiled loop.
-RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
+RankedTensorType PackOp::inferPackedType(ShapedType sourceType,
                                          ArrayRef<int64_t> innerTileSizes,
                                          ArrayRef<int64_t> innerDimsPos,
                                          ArrayRef<int64_t> outerDimsPerm) {
@@ -4943,7 +4943,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
           rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
     }
     Value dest = packOp.getDest();
-    RankedTensorType originalResultType = packOp.getDestType();
+    ShapedType originalResultType = packOp.getDestType();
     bool needUpdateDestType = (destShape != originalResultType.getShape());
     if (needUpdateDestType) {
       auto newDestType = packOp.getDestType().clone(destShape);
@@ -4953,7 +4953,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     rewriter.modifyOpInPlace(packOp, [&] {
       packOp.getSourceMutable().assign(source);
       packOp.getDestMutable().assign(dest);
-      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+      packOp.getResult().setType(cast<ShapedType>(dest.getType()));
     });
     // Insert a cast if needed
     if (needUpdateDestType) {
@@ -4970,7 +4970,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
 
 template <typename PackOrUnpackOp>
 static bool isLikePadUnPad(PackOrUnpackOp packOp,
-                           RankedTensorType packedTensorType) {
+                           ShapedType packedTensorType) {
   static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
                     std::is_same<PackOrUnpackOp, UnPackOp>::value,
                 "Function meant for pack/unpack");
@@ -5274,7 +5274,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
 }
 
 bool UnPackOp::isLikeUnPad() {
-  RankedTensorType packedTensorType = getSourceType();
+  ShapedType packedTensorType = getSourceType();
   return isLikePadUnPad(*this, packedTensorType);
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 0984b6988b93b..599aa3b6668df 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -111,7 +111,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
     if (packOp.getPaddingValue())
       return rewriter.notifyMatchFailure(packOp, "expects no padding value");
 
-    RankedTensorType sourceType = packOp.getSourceType();
+    ShapedType sourceType = packOp.getSourceType();
     if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
         failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
                           packOp.getStaticTiles())) &&
@@ -119,7 +119,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
       return failure();
     }
 
-    RankedTensorType destType = packOp.getDestType();
+    ShapedType destType = packOp.getDestType();
     auto reassociation =
         getReassociationIndicesForReshape(sourceType, destType);
     if (!reassociation)
@@ -157,8 +157,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
           "expects outer_dims_perm is empty or an identity permutation");
     }
 
-    RankedTensorType sourceType = unpackOp.getSourceType();
-    RankedTensorType destType = unpackOp.getDestType();
+    ShapedType sourceType = unpackOp.getSourceType();
+    ShapedType destType = unpackOp.getDestType();
     if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
       return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
 
@@ -173,7 +173,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
 
   LogicalResult matchAndRewrite(UnPackOp unpackOp,
                                 PatternRewriter &rewriter) const override {
-    RankedTensorType destType = unpackOp.getDestType();
+    ShapedType destType = unpackOp.getDestType();
     if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
         failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
                           unpackOp.getStaticTiles())) &&
@@ -181,7 +181,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
       return failure();
     }
 
-    RankedTensorType sourceType = unpackOp.getSourceType();
+    ShapedType sourceType = unpackOp.getSourceType();
     auto reassociation =
         getReassociationIndicesForReshape(sourceType, destType);
     if (!reassociation)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index dcd50cc44f81b..7ed211841c53f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -359,7 +359,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
 
-  RankedTensorType packedTensorType = unPackOp.getSourceType();
+  ShapedType packedTensorType = unPackOp.getSourceType();
   int64_t packedRank = packedTensorType.getRank();
 
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -396,10 +396,29 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
   applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
 
   // 3. Transpose packedShape to stripMinedShape.
-  RankedTensorType stripMinedTensorType =
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
-  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
-      stripMinedTensorType, packingMetadata.reassociations);
+  ShapedType stripMinedType;
+  if (auto tensorType = dyn_cast<TensorType>(packedTensorType)) {
+    stripMinedType =
+        RankedTensorType::get(stripMinedShape, tensorType.getElementType());
+  } else if (auto memrefType = dyn_cast<MemRefType>(packedTensorType)) {
+    stripMinedType =
+        MemRefType::get(stripMinedShape, memrefType.getElementType());
+  }
+  ShapedType collapsedType;
+  if (isa<TensorType>(stripMinedType)) {
+    collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+        cast<RankedTensorType>(stripMinedType),
+        packingMetadata.reassociations);
+  } else if (isa<MemRefType>(stripMinedType)) {
+    auto memrefTy = cast<MemRefType>(stripMinedType);
+    auto tensorTy =
+        RankedTensorType::get(memrefTy.getShape(), memrefTy.getElementType());
+    auto collapsedTensorType = tensor::CollapseShapeOp::inferCollapsedType(
+        tensorTy, packingMetadata.reassociations);
+    // Rebuild the collapsed tensor type as a memref (same memory space).
+    collapsedType = MemRefType::get(collapsedTensorType.getShape(),
+                                    collapsedTensorType.getElementType());
+  }
 
   // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
   // permutation.
@@ -407,7 +426,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
       tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
   applyPermutationToVector(dims, packedToStripMinedShapePerm);
   auto emptyOp = rewriter.create<tensor::EmptyOp>(
-      loc, dims, stripMinedTensorType.getElementType());
+      loc, dims, stripMinedType.getElementType());
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
       loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ae04c2b6b2a5b..25ad5e38addbe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1669,7 +1669,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);
 
-  RankedTensorType unpackTensorType = unpackOp.getSourceType();
+  ShapedType unpackTensorType = unpackOp.getSourceType();
 
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
diff --git a/mlir/lib/Tools/mlir-opt/launch.json b/mlir/lib/Tools/mlir-opt/launch.json
new file mode 100644
index 0000000000000..5a686d02e2dfb
--- /dev/null
+++ b/mlir/lib/Tools/mlir-opt/launch.json
@@ -0,0 +1,13 @@
+{
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "ma",
+            "type": "lldb",
+            "request": "launch",
+            "program": "/Users/ita/src/iree-build/tools/iree-opt --show-dialects",
+            "args": [],
+            "cwd": "${workspaceFolder}"
+        }
+    ]
+}
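
Aside on the `lowerUnPack` change in this patch: both branches implement "clone a ShapedType with a new shape, staying in the same type family". A minimal sketch of that dispatch as a standalone helper (assumed name, not part of the patch; like the patch, it does not carry over memref layout or memory space):

```cpp
// Clone a shaped type with a new shape, preserving only the element type.
static ShapedType cloneWithShape(ShapedType type, ArrayRef<int64_t> shape) {
  if (auto tensorTy = dyn_cast<TensorType>(type))
    return RankedTensorType::get(shape, tensorTy.getElementType());
  auto memrefTy = cast<MemRefType>(type);
  return MemRefType::get(shape, memrefTy.getElementType());
}
```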

From 19201c69e23578a69583bb98415f9c9583cb5c41 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 14:50:46 +0900
Subject: [PATCH 5/7] lint

---
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp |  1 -
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp        | 12 ++++++------
 mlir/lib/Tools/mlir-opt/launch.json               | 13 -------------
 3 files changed, 6 insertions(+), 20 deletions(-)
 delete mode 100644 mlir/lib/Tools/mlir-opt/launch.json

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 7ed211841c53f..36e01ef46b30b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -415,7 +415,6 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
         RankedTensorType::get(memrefTy.getShape(), memrefTy.getElementType());
     auto collapsedTensorType = tensor::CollapseShapeOp::inferCollapsedType(
         tensorTy, packingMetadata.reassociations);
-    // Rebuild the collapsed tensor type as a memref (same memory space).
     collapsedType = MemRefType::get(collapsedTensorType.getShape(),
                                     collapsedTensorType.getElementType());
   }
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 4267732571801..9a2bd3493f6af 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -315,11 +315,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
     // have proven that these are not sliced. In this case we just take
     // the full extent of each dimension in the reassociation list.
     if (linearizedDimensions[it.index()]) {
-      llvm::append_range(
-          offsetsSizesAndStrides,
-          llvm::map_range(it.value(), [&](int64_t idx) -> Range {
-            return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
-          }));
+      llvm::append_range(offsetsSizesAndStrides,
+                         llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+                           return {zeroAttr, collapseShapeInputShape[idx],
+                                   oneAttr};
+                         }));
       continue;
     }
 
@@ -485,7 +485,7 @@ PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
 }
 
 OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
-  ShapedType result,
+                                         ShapedType result,
                                          std::optional<Attribute> cst) {
   if (source && source.isSplat() && result.hasStaticShape() &&
       (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
diff --git a/mlir/lib/Tools/mlir-opt/launch.json b/mlir/lib/Tools/mlir-opt/launch.json
deleted file mode 100644
index 5a686d02e2dfb..0000000000000
--- a/mlir/lib/Tools/mlir-opt/launch.json
+++ /dev/null
@@ -1,13 +0,0 @@
-{
-    "version": "0.2.0",
-    "configurations": [
-        {
-            "name": "ma",
-            "type": "lldb",
-            "request": "launch",
-            "program": "/Users/ita/src/iree-build/tools/iree-opt --show-dialects",
-            "args": [],
-            "cwd": "${workspaceFolder}"
-        }
-    ]
-}

From b99b92030f2f664607f43554d2b7bc722c98c2c1 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 15:26:13 +0900
Subject: [PATCH 6/7] lint

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f4f08d9d4acf7..eca8cea3e6323 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4434,8 +4434,8 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
 
   // Verify inner_dims_pos and outer_dims_perm.
   ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
-                                      ? packOrUnPack.getSourceType()
-                                      : packOrUnPack.getDestType();
+                                ? packOrUnPack.getSourceType()
+                                : packOrUnPack.getDestType();
   size_t unpackedRank = unpackedType.getRank();
   ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
   ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
@@ -4969,8 +4969,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
 }
 
 template <typename PackOrUnpackOp>
-static bool isLikePadUnPad(PackOrUnpackOp packOp,
-                           ShapedType packedTensorType) {
+static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
   static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
                     std::is_same<PackOrUnpackOp, UnPackOp>::value,
                 "Function meant for pack/unpack");

From be6a1193579633d7b678a30a9a80e5dee89a51e1 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 16:19:20 +0900
Subject: [PATCH 7/7] add

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index eca8cea3e6323..a19039fbca67d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5001,9 +5001,13 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
 }
 
 bool PackOp::isLikePad() {
-  auto packedTensorType =
-      llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
-  return isLikePadUnPad(*this, packedTensorType);
+  if (auto packedTensorType =
+          llvm::dyn_cast<RankedTensorType>((*this)->getResultTypes().front()))
+    return isLikePadUnPad(*this, packedTensorType);
+  if (auto packedTensorType =
+          llvm::dyn_cast<MemRefType>((*this)->getResultTypes().front()))
+    return isLikePadUnPad(*this, packedTensorType);
+  return false;
 }
 
 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {


