[Mlir-commits] [mlir] [draft][mlir][linalg] pack, unpack to take memref inputs (PR #129036)

Hyunsung Lee llvmlistbot at llvm.org
Thu Feb 27 15:00:44 PST 2025


https://github.com/ita9naiwa updated https://github.com/llvm/llvm-project/pull/129036

>From 4d523adc3cf5eb581c43395e66aaa0012dbc179b Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Thu, 27 Feb 2025 19:54:30 +0900
Subject: [PATCH 1/2] draft

---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    | 72 +++++++++++++++++--
 .../Dialect/Linalg/IR/RelayoutOpInterface.td  |  1 +
 .../mlir/Dialect/Utils/ReshapeOpsUtils.h      |  4 +-
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp    |  2 +-
 4 files changed, 69 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 1e48a5e3a20ee..f8a4657c564ce 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -77,7 +77,20 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
     /// with `inner_dims_pos` rather than the packed tensor.
     SmallVector<int64_t> getTiledOuterDims();
   }];
-
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    ShapedType getInputType() {
+        return cast<ShapedType>(getInput().getType());
+      }
+      ShapedType getOutputType() {
+        return cast<ShapedType>(getOutput().getType());
+      }
+      int64_t getInputRank() {
+        return getInputType().getRank();
+      }
+      int64_t getOutputRank() {
+        return getOutputType().getRank();
+      }
+    }];
   let hasVerifier = 1;
 }
 
@@ -152,14 +165,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     // Note: Only tiled dimensions can be padded.
     ```
   }];
-  let arguments = (ins AnyRankedTensor:$source,
-                       AnyRankedTensor:$dest,
+  let arguments = (ins AnyShaped:$source,
+                       AnyShaped:$dest,
                        Optional<AnyType>:$padding_value,
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
                        DenseI64ArrayAttr:$inner_dims_pos,
                        Variadic<Index>:$inner_tiles,
                        DenseI64ArrayAttr:$static_inner_tiles);
-  let results = (outs AnyRankedTensor:$result);
+  let results = (outs AnyShaped:$result);
   let assemblyFormat = [{
     $source
     (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
@@ -179,6 +192,28 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
+    Value getOutput() {
+      return getDpsInitOperand(0)->get();
+    }
+
+    // Return the input operand.
+    Value getInput() {
+      return getDpsInputOperand(0)->get();
+    }
+    ShapedType getInputType() {
+      return cast<ShapedType>(getInput().getType());
+    }
+    ShapedType getOutputType() {
+      return cast<ShapedType>(getDest().getType());
+    }
+    int64_t getInputRank() {
+      return getInputType().getRank();
+    }
+    int64_t getOutputRank() {
+      return getOutputType().getRank();
+    }
+
+    LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
     // Method to get the shape of the result as `SmallVector<OpFoldResult>`.
     // This is a static method to allow getting the shape of the destination
     // expected while creating a `pack` op.
@@ -229,6 +264,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     /// 2. pads the other ones, and
     /// 3. doesn't shuffle the dimensions
     bool isLikePad();
+
   }];
 
   let hasCanonicalizeMethod = 1;
@@ -279,13 +315,13 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
         : tensor<8x16x8x32xf32> -> tensor<128x256xf32>
     ```
   }];
-  let arguments = (ins AnyRankedTensor:$source,
-                       AnyRankedTensor:$dest,
+  let arguments = (ins AnyShaped:$source,
+                       AnyShaped:$dest,
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
                        DenseI64ArrayAttr:$inner_dims_pos,
                        Variadic<Index>:$inner_tiles,
                        DenseI64ArrayAttr:$static_inner_tiles);
-  let results = (outs AnyRankedTensor:$result);
+  let results = (outs AnyShaped:$result);
   let assemblyFormat = [{
     $source
     (`outer_dims_perm` `=` $outer_dims_perm^)?
@@ -303,6 +339,28 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
+    Value getOutput() {
+      return getDpsInitOperand(0)->get();
+    }
+
+    // Return the input operand.
+    Value getInput() {
+      return getDpsInputOperand(0)->get();
+    }
+    ShapedType getInputType() {
+      return cast<ShapedType>(getInput().getType());
+    }
+    ShapedType getOutputType() {
+      return cast<ShapedType>(getDest().getType()); // getDest() 사용
+    }
+    int64_t getInputRank() {
+      return getInputType().getRank();
+    }
+    int64_t getOutputRank() {
+      return getOutputType().getRank();
+    }
+    LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
+
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
         ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
index 2dec2fc4396f4..467d862d277eb 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
@@ -10,6 +10,7 @@
 #define LINALG_IR_RELAYOUTOPINTERFACE
 
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/IR/OpBase.td"
 
 def LinalgRelayoutOpInterface : OpInterface<"RelayoutOpInterface"> {
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 3af89a6ab3799..a86bf74a7b6a1 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -451,7 +451,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
 ///    %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
 ///          tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
 ///
-///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] : 
+///    %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
 ///          tensor<1x1x1x10xf32> into tensor<1x10xf32>
 ///    %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
 ///          tensor<1x10xf32> into tensor<10x10xf32>
@@ -573,7 +573,7 @@ PackingMetadata computePackingMetadata(int64_t packedRank,
 /// Removes the op and replaces the constant with a new constant of the result
 /// shape. When an optional cst attribute is passed, it is reshaped only if the
 /// splat value matches the value in the attribute.
-OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result,
+OpFoldResult reshapeConstantSource(DenseElementsAttr source, ShapedType result,
                                    std::optional<Attribute> cst = std::nullopt);
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 0336423c57b1d..4267732571801 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -485,7 +485,7 @@ PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
 }
 
 OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
-                                         TensorType result,
+  ShapedType result,
                                          std::optional<Attribute> cst) {
   if (source && source.isSplat() && result.hasStaticShape() &&
       (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))

>From 4f2dbf4848092942a7932387e39d3c1220d78923 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Fri, 28 Feb 2025 08:00:32 +0900
Subject: [PATCH 2/2] draft

---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    | 44 -------------------
 1 file changed, 44 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index f8a4657c564ce..6e2c6171132f5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -192,28 +192,6 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
-    Value getOutput() {
-      return getDpsInitOperand(0)->get();
-    }
-
-    // Return the input operand.
-    Value getInput() {
-      return getDpsInputOperand(0)->get();
-    }
-    ShapedType getInputType() {
-      return cast<ShapedType>(getInput().getType());
-    }
-    ShapedType getOutputType() {
-      return cast<ShapedType>(getDest().getType());
-    }
-    int64_t getInputRank() {
-      return getInputType().getRank();
-    }
-    int64_t getOutputRank() {
-      return getOutputType().getRank();
-    }
-
-    LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
     // Method to get the shape of the result as `SmallVector<OpFoldResult>`.
     // This is a static method to allow getting the shape of the destination
     // expected while creating a `pack` op.
@@ -339,28 +317,6 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
-    Value getOutput() {
-      return getDpsInitOperand(0)->get();
-    }
-
-    // Return the input operand.
-    Value getInput() {
-      return getDpsInputOperand(0)->get();
-    }
-    ShapedType getInputType() {
-      return cast<ShapedType>(getInput().getType());
-    }
-    ShapedType getOutputType() {
-      return cast<ShapedType>(getDest().getType()); // getDest() 사용
-    }
-    int64_t getInputRank() {
-      return getInputType().getRank();
-    }
-    int64_t getOutputRank() {
-      return getOutputType().getRank();
-    }
-    LogicalResult generateScalarImplementation(OpBuilder &builder, Location loc, ValueRange ivs);
-
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
         ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);



More information about the Mlir-commits mailing list