[Mlir-commits] [mlir] [mlir][linalg] Add support to pass attributes to the packed ops (PR #79526)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 25 16:02:45 PST 2024


https://github.com/yzhang93 created https://github.com/llvm/llvm-project/pull/79526

There is a use case where we need to pass some attributes to the packed op (linalg.generic) during packing or pack transposition. With this change, the original op's attributes (other than the pre-defined ones) are preserved on the packed op or the transposed op.

From 343c0efcfa479a851b34dac9deb57bb16ecb66aa Mon Sep 17 00:00:00 2001
From: yzhang93 <zhyuhang88 at gmail.com>
Date: Thu, 25 Jan 2024 15:21:48 -0800
Subject: [PATCH] [mlir][linalg] Add support to pass attributes to the packed
 ops

---
 mlir/include/mlir/Dialect/Linalg/Utils/Utils.h    | 4 +++-
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 7 +++++--
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index b3397ae131b56f9..79a04e5978bb25c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -369,7 +369,9 @@ struct GenerateLoopNest {
 /// Returns an attribute list that excludes pre-defined attributes.
 template <typename OpTy>
 SmallVector<NamedAttribute> getPrunedAttributeList(OpTy op) {
-  auto elidedAttrs = llvm::to_vector(op.getAttributeNames());
+  // op.getAttributeNames() doesn't work when the op is linalg::LinalgOp.
+  // Instead use the static function to get attribute names.
+  auto elidedAttrs = llvm::to_vector(linalg::GenericOp::getAttributeNames());
   if (isa<linalg::LinalgOp>(op.getOperation()))
     elidedAttrs.push_back(LinalgDialect::kMemoizedIndexingMapsAttrName);
   return getPrunedAttributeList(op, elidedAttrs);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 4df105af5bcd6f1..de224cf3ddc2aeb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -604,9 +604,10 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
       ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
   ValueRange inits =
       ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
+  auto prunedAttrs = getPrunedAttributeList(linalgOp);
   auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
       linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
-      iteratorTypes);
+      iteratorTypes, /*bodyBuild=*/nullptr, prunedAttrs);
   packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
 
   // Step 4. Propagate packing to all the op results.
@@ -685,6 +686,7 @@ static LinalgOp transposeOneLinalgOperandAndReplace(
   operands[opOperand.getOperandNumber()] = transposedValue;
 
   ValueRange operandsRef(operands);
+  auto prunedAttrs = getPrunedAttributeList(linalgOp);
   auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
       /*location=*/linalgOp->getLoc(),
       /*resultTensorTypes=*/
@@ -692,7 +694,8 @@ static LinalgOp transposeOneLinalgOperandAndReplace(
       /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
       /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
       /*indexingMaps=*/indexingMaps,
-      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
+      /*iteratorTypes=*/linalgOp.getIteratorTypesArray(),
+      /*bodyBuild=*/nullptr, prunedAttrs);
   transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
   rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
 



More information about the Mlir-commits mailing list