[Mlir-commits] [mlir] [mlir][linalg] Add support to pass attributes to the packed ops (PR #79526)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 25 16:03:12 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Vivian (yzhang93)

<details>
<summary>Changes</summary>

We have a use case that requires passing attributes through to the packed op (linalg.generic) during packing or pack transposition. With this change, the original op's attributes, excluding pre-defined ones, are preserved on the packed op (packedOp) or the transposed op (transposedOp).

---
Full diff: https://github.com/llvm/llvm-project/pull/79526.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/Utils/Utils.h (+3-1) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+5-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index b3397ae131b56f9..79a04e5978bb25c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -369,7 +369,9 @@ struct GenerateLoopNest {
 /// Returns an attribute list that excludes pre-defined attributes.
 template <typename OpTy>
 SmallVector<NamedAttribute> getPrunedAttributeList(OpTy op) {
-  auto elidedAttrs = llvm::to_vector(op.getAttributeNames());
+  // op.getAttributeNames() doesn't work when the op is linalg::LinalgOp.
+  // Instead use the static function to get attribute names.
+  auto elidedAttrs = llvm::to_vector(linalg::GenericOp::getAttributeNames());
   if (isa<linalg::LinalgOp>(op.getOperation()))
     elidedAttrs.push_back(LinalgDialect::kMemoizedIndexingMapsAttrName);
   return getPrunedAttributeList(op, elidedAttrs);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 4df105af5bcd6f1..de224cf3ddc2aeb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -604,9 +604,10 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
       ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
   ValueRange inits =
       ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
+  auto prunedAttrs = getPrunedAttributeList(linalgOp);
   auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
       linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
-      iteratorTypes);
+      iteratorTypes, /*bodyBuild=*/nullptr, prunedAttrs);
   packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
 
   // Step 4. Propagate packing to all the op results.
@@ -685,6 +686,7 @@ static LinalgOp transposeOneLinalgOperandAndReplace(
   operands[opOperand.getOperandNumber()] = transposedValue;
 
   ValueRange operandsRef(operands);
+  auto prunedAttrs = getPrunedAttributeList(linalgOp);
   auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
       /*location=*/linalgOp->getLoc(),
       /*resultTensorTypes=*/
@@ -692,7 +694,8 @@ static LinalgOp transposeOneLinalgOperandAndReplace(
       /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
       /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
       /*indexingMaps=*/indexingMaps,
-      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
+      /*iteratorTypes=*/linalgOp.getIteratorTypesArray(),
+      /*bodyBuild=*/nullptr, prunedAttrs);
   transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
   rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/79526


More information about the Mlir-commits mailing list