[Mlir-commits] [mlir] [mlir][linalg] Add support to pass attributes to the packed ops (PR #79526)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 29 11:21:14 PST 2024


https://github.com/yzhang93 updated https://github.com/llvm/llvm-project/pull/79526

>From 343c0efcfa479a851b34dac9deb57bb16ecb66aa Mon Sep 17 00:00:00 2001
From: yzhang93 <zhyuhang88 at gmail.com>
Date: Thu, 25 Jan 2024 15:21:48 -0800
Subject: [PATCH 1/2] [mlir][linalg] Add support to pass attributes to the
 packed ops

---
 mlir/include/mlir/Dialect/Linalg/Utils/Utils.h    | 4 +++-
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 7 +++++--
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index b3397ae131b56f..79a04e5978bb25 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -369,7 +369,9 @@ struct GenerateLoopNest {
 /// Returns an attribute list that excludes pre-defined attributes.
 template <typename OpTy>
 SmallVector<NamedAttribute> getPrunedAttributeList(OpTy op) {
-  auto elidedAttrs = llvm::to_vector(op.getAttributeNames());
+  // op.getAttributeNames() doesn't work when the op is linalg::LinalgOp.
+  // Instead use the static function to get attribute names.
+  auto elidedAttrs = llvm::to_vector(linalg::GenericOp::getAttributeNames());
   if (isa<linalg::LinalgOp>(op.getOperation()))
     elidedAttrs.push_back(LinalgDialect::kMemoizedIndexingMapsAttrName);
   return getPrunedAttributeList(op, elidedAttrs);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 4df105af5bcd6f..de224cf3ddc2ae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -604,9 +604,10 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
       ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
   ValueRange inits =
       ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
+  auto prunedAttrs = getPrunedAttributeList(linalgOp);
   auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
       linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
-      iteratorTypes);
+      iteratorTypes, /*bodyBuild=*/nullptr, prunedAttrs);
   packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
 
   // Step 4. Propagate packing to all the op results.
@@ -685,6 +686,7 @@ static LinalgOp transposeOneLinalgOperandAndReplace(
   operands[opOperand.getOperandNumber()] = transposedValue;
 
   ValueRange operandsRef(operands);
+  auto prunedAttrs = getPrunedAttributeList(linalgOp);
   auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
       /*location=*/linalgOp->getLoc(),
       /*resultTensorTypes=*/
@@ -692,7 +694,8 @@ static LinalgOp transposeOneLinalgOperandAndReplace(
       /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
       /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
       /*indexingMaps=*/indexingMaps,
-      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
+      /*iteratorTypes=*/linalgOp.getIteratorTypesArray(),
+      /*bodyBuild=*/nullptr, prunedAttrs);
   transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
   rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
 

>From 99ba898fc11854e9b340352dfec301b42735a1fe Mon Sep 17 00:00:00 2001
From: yzhang93 <zhyuhang88 at gmail.com>
Date: Mon, 29 Jan 2024 11:20:24 -0800
Subject: [PATCH 2/2] Use explicit template specialization and add a test

---
 .../include/mlir/Dialect/Linalg/Utils/Utils.h | 15 +++++--
 .../Dialect/Linalg/transform-op-pack.mlir     | 40 +++++++++++++++++++
 2 files changed, 52 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 79a04e5978bb25..16c968ecb3f47f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -367,16 +367,25 @@ struct GenerateLoopNest {
 };
 
 /// Returns an attribute list that excludes pre-defined attributes.
+/// `linalg::LinalgOp` is an op interface and has no `getAttributeNames()`
+/// method, so it is handled by the explicit specialization below, which uses
+/// the attribute names of `linalg::GenericOp` instead. This assumes that all
+/// Linalg ops carry the same inherent attributes as `linalg.generic`.
 template <typename OpTy>
 SmallVector<NamedAttribute> getPrunedAttributeList(OpTy op) {
-  // op.getAttributeNames() doesn't work when the op is linalg::LinalgOp.
-  // Instead use the static function to get attribute names.
-  auto elidedAttrs = llvm::to_vector(linalg::GenericOp::getAttributeNames());
+  auto elidedAttrs = llvm::to_vector(op.getAttributeNames());
   if (isa<linalg::LinalgOp>(op.getOperation()))
     elidedAttrs.push_back(LinalgDialect::kMemoizedIndexingMapsAttrName);
   return getPrunedAttributeList(op, elidedAttrs);
 }
 
+template <>
+inline SmallVector<NamedAttribute> getPrunedAttributeList(linalg::LinalgOp op) {
+  auto elidedAttrs = llvm::to_vector(linalg::GenericOp::getAttributeNames());
+  elidedAttrs.push_back(LinalgDialect::kMemoizedIndexingMapsAttrName);
+  return getPrunedAttributeList(op, elidedAttrs);
+}
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-pack.mlir b/mlir/test/Dialect/Linalg/transform-op-pack.mlir
index cf6339ce3de82e..9f5a9bce5a1dbe 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pack.mlir
@@ -666,3 +666,43 @@ module attributes {transform.with_named_sequence} {
       transform.yield
   }
 }
+
+// -----
+
+func.func @pass_attribute_pack_transpose(%A: tensor<32x32xf32>, %B: tensor<32x32xf32>, %C: tensor<32x32xf32>)
+    -> tensor<32x32xf32> {
+  %0 = linalg.matmul {test_attribute} ins(%A, %B: tensor<32x32xf32>, tensor<32x32xf32>)
+                     outs(%C: tensor<32x32xf32>)
+    -> tensor<32x32xf32>
+  return %0 : tensor<32x32xf32>
+}
+
+// CHECK-LABEL: pass_attribute_pack_transpose
+//       CHECK: tensor.pack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [4, 8]
+//  CHECK-SAME:   into %{{.+}} : tensor<32x32xf32> -> tensor<8x4x4x8xf32>
+//       CHECK: tensor.pack %{{.+}} outer_dims_perm = [1, 0]
+//  CHECK-SAME:   inner_dims_pos = [0, 1] inner_tiles = [8, 8]
+//  CHECK-SAME:   into %{{.+}} : tensor<32x32xf32> -> tensor<4x4x8x8xf32>
+//       CHECK: tensor.pack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [4, 8]
+//  CHECK-SAME:   into %{{.+}} : tensor<32x32xf32> -> tensor<8x4x4x8xf32>
+//       CHECK: linalg.generic {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}],
+//       CHECK:   iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+//  CHECK-SAME: ins(%{{.*}} : tensor<8x4x4x8xf32>, tensor<4x4x8x8xf32>)
+//  CHECK-SAME: outs(%{{.*}} : tensor<8x4x4x8xf32>)
+//  CHECK-SAME: attrs = {test_attribute}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+      %1 = transform.structured.pack %0 packed_sizes = [4, 8, 8]
+        : (!transform.any_op) -> (!transform.op<"linalg.generic">)
+      %pack = transform.get_producer_of_operand %1[1]
+      : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.pack">)
+      %2, %pack_2, %empty_unpack_2 =
+      transform.structured.pack_transpose %pack with_compute_op(%1)
+      outer_perm = [1, 0] inner_perm = [1, 0]
+       : (!transform.op<"tensor.pack">, !transform.op<"linalg.generic">)
+      -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.any_op)
+      transform.yield
+  }
+}
\ No newline at end of file



More information about the Mlir-commits mailing list