[Mlir-commits] [mlir] [mlir][linalg] Add support to pass attributes to the packed ops (PR #79526)
llvmlistbot at llvm.org
Mon Jan 29 11:26:18 PST 2024
https://github.com/yzhang93 updated https://github.com/llvm/llvm-project/pull/79526
From 343c0efcfa479a851b34dac9deb57bb16ecb66aa Mon Sep 17 00:00:00 2001
From: yzhang93 <zhyuhang88 at gmail.com>
Date: Thu, 25 Jan 2024 15:21:48 -0800
Subject: [PATCH 1/2] [mlir][linalg] Add support to pass attributes to the
packed ops
---
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h | 4 +++-
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 7 +++++--
2 files changed, 8 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index b3397ae131b56f9..79a04e5978bb25c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -369,7 +369,9 @@ struct GenerateLoopNest {
/// Returns an attribute list that excludes pre-defined attributes.
template <typename OpTy>
SmallVector<NamedAttribute> getPrunedAttributeList(OpTy op) {
- auto elidedAttrs = llvm::to_vector(op.getAttributeNames());
+ // op.getAttributeNames() doesn't work when the op is linalg::LinalgOp.
+ // Instead use the static function to get attribute names.
+ auto elidedAttrs = llvm::to_vector(linalg::GenericOp::getAttributeNames());
if (isa<linalg::LinalgOp>(op.getOperation()))
elidedAttrs.push_back(LinalgDialect::kMemoizedIndexingMapsAttrName);
return getPrunedAttributeList(op, elidedAttrs);
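
For reference, the two-argument overload called on the last line above does the actual filtering. A minimal sketch of what it amounts to, assuming an Operation handle and the elided-name list built in the template (the helper name below is hypothetical, not the in-tree function):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Operation.h"

static llvm::SmallVector<mlir::NamedAttribute>
prunedAttributeListSketch(mlir::Operation *op,
                          llvm::ArrayRef<llvm::StringRef> elidedAttrs) {
  // Keep every attribute on `op` whose name is not in the elided list.
  llvm::SmallVector<mlir::NamedAttribute> pruned;
  for (mlir::NamedAttribute attr : op->getAttrs())
    if (!llvm::is_contained(elidedAttrs, attr.getName().strref()))
      pruned.push_back(attr);
  return pruned;
}
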
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 4df105af5bcd6f1..de224cf3ddc2aeb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -604,9 +604,10 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
ValueRange inits =
ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
+ auto prunedAttrs = getPrunedAttributeList(linalgOp);
auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
- iteratorTypes);
+ iteratorTypes, /*bodyBuild=*/nullptr, prunedAttrs);
packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
// Step 4. Propagate packing to all the op results.
@@ -685,6 +686,7 @@ static LinalgOp transposeOneLinalgOperandAndReplace(
operands[opOperand.getOperandNumber()] = transposedValue;
ValueRange operandsRef(operands);
+ auto prunedAttrs = getPrunedAttributeList(linalgOp);
auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
/*location=*/linalgOp->getLoc(),
/*resultTensorTypes=*/
@@ -692,7 +694,8 @@ static LinalgOp transposeOneLinalgOperandAndReplace(
/*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
/*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
/*indexingMaps=*/indexingMaps,
- /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
+ /*iteratorTypes=*/linalgOp.getIteratorTypesArray(),
+ /*bodyBuild=*/nullptr, prunedAttrs);
transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
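
With both create calls above now forwarding prunedAttrs through the GenericOp builder's trailing attributes argument, discardable attributes placed on the original op survive the rewrite. A small illustrative sketch of the intended effect, assuming a surrounding pattern or pass that already has a rewriter, a linalg op, and packed sizes (the helper name and setup are hypothetical):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinAttributes.h"
#include <cassert>

// Tag the op, pack it, and expect the tag to be carried over, since
// "test_attribute" is not among the elided attribute names.
static void packAndCheckAttrSketch(mlir::RewriterBase &rewriter,
                                   mlir::linalg::LinalgOp linalgOp,
                                   llvm::ArrayRef<mlir::OpFoldResult> packedSizes) {
  linalgOp->setAttr("test_attribute",
                    mlir::UnitAttr::get(linalgOp->getContext()));
  mlir::FailureOr<mlir::linalg::PackResult> packed =
      mlir::linalg::pack(rewriter, linalgOp, packedSizes);
  if (mlir::succeeded(packed))
    assert(packed->packedLinalgOp->hasAttr("test_attribute") &&
           "pruned attributes should be forwarded to the packed op");
}
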
From 36966c11e38fcb62a952914e1b05a9ef69f27bcc Mon Sep 17 00:00:00 2001
From: yzhang93 <zhyuhang88 at gmail.com>
Date: Mon, 29 Jan 2024 11:20:24 -0800
Subject: [PATCH 2/2] Use an explicit template specialization and add a test
---
.../include/mlir/Dialect/Linalg/Utils/Utils.h | 15 +++++--
.../Dialect/Linalg/transform-op-pack.mlir | 40 +++++++++++++++++++
2 files changed, 52 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 79a04e5978bb25c..ebe2b0ef0339f80 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -367,16 +367,25 @@ struct GenerateLoopNest {
};
/// Returns an attribute list that excludes pre-defined attributes.
+/// `op.getAttributeNames()` is not available when the op is only known
+/// through the linalg::LinalgOp interface. For that case, a template
+/// specialization below takes the attribute names from linalg::GenericOp,
+/// since all Linalg ops carry the same attributes as linalg.generic.
template <typename OpTy>
SmallVector<NamedAttribute> getPrunedAttributeList(OpTy op) {
- // op.getAttributeNames() doesn't work when the op is linalg::LinalgOp.
- // Instead use the static function to get attribute names.
- auto elidedAttrs = llvm::to_vector(linalg::GenericOp::getAttributeNames());
+ auto elidedAttrs = llvm::to_vector(op.getAttributeNames());
if (isa<linalg::LinalgOp>(op.getOperation()))
elidedAttrs.push_back(LinalgDialect::kMemoizedIndexingMapsAttrName);
return getPrunedAttributeList(op, elidedAttrs);
}
+template <>
+inline SmallVector<NamedAttribute> getPrunedAttributeList(linalg::LinalgOp op) {
+ auto elidedAttrs = llvm::to_vector(linalg::GenericOp::getAttributeNames());
+ elidedAttrs.push_back(LinalgDialect::kMemoizedIndexingMapsAttrName);
+ return getPrunedAttributeList(op, elidedAttrs);
+}
+
} // namespace linalg
} // namespace mlir
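
The explicit specialization leaves the primary template untouched for concrete op types, while call sites that only hold the LinalgOp interface (as in Transforms.cpp above) get GenericOp's attribute names; the inline keyword keeps the header-defined specialization ODR-safe across translation units. A minimal usage sketch, assuming someOp is an Operation* known to implement the interface:

// OpTy is deduced as linalg::LinalgOp, so the specialization above is chosen
// and linalg::GenericOp::getAttributeNames() supplies the elided names.
mlir::linalg::LinalgOp linalgOp = mlir::cast<mlir::linalg::LinalgOp>(someOp);
llvm::SmallVector<mlir::NamedAttribute> prunedAttrs =
    mlir::linalg::getPrunedAttributeList(linalgOp);
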
diff --git a/mlir/test/Dialect/Linalg/transform-op-pack.mlir b/mlir/test/Dialect/Linalg/transform-op-pack.mlir
index cf6339ce3de82e4..f2f8bc4de093048 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pack.mlir
@@ -666,3 +666,43 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+func.func @pass_attribute_pack_transpose(%A: tensor<32x32xf32>, %B: tensor<32x32xf32>, %C: tensor<32x32xf32>)
+ -> tensor<32x32xf32> {
+ %0 = linalg.matmul {test_attribute} ins(%A, %B: tensor<32x32xf32>, tensor<32x32xf32>)
+ outs(%C: tensor<32x32xf32>)
+ -> tensor<32x32xf32>
+ return %0 : tensor<32x32xf32>
+}
+
+// CHECK-LABEL: pass_attribute_pack_transpose
+// CHECK: tensor.pack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [4, 8]
+// CHECK-SAME: into %{{.+}} : tensor<32x32xf32> -> tensor<8x4x4x8xf32>
+// CHECK: tensor.pack %{{.+}} outer_dims_perm = [1, 0]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 8]
+// CHECK-SAME: into %{{.+}} : tensor<32x32xf32> -> tensor<4x4x8x8xf32>
+// CHECK: tensor.pack %{{.+}} inner_dims_pos = [0, 1] inner_tiles = [4, 8]
+// CHECK-SAME: into %{{.+}} : tensor<32x32xf32> -> tensor<8x4x4x8xf32>
+// CHECK: linalg.generic {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}],
+// CHECK: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%{{.*}} : tensor<8x4x4x8xf32>, tensor<4x4x8x8xf32>)
+// CHECK-SAME: outs(%{{.*}} : tensor<8x4x4x8xf32>)
+// CHECK-SAME: attrs = {test_attribute}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.pack %0 packed_sizes = [4, 8, 8]
+ : (!transform.any_op) -> (!transform.op<"linalg.generic">)
+ %pack = transform.get_producer_of_operand %1[1]
+ : (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.pack">)
+ %2, %pack_2, %empty_unpack_2 =
+ transform.structured.pack_transpose %pack with_compute_op(%1)
+ outer_perm = [1, 0] inner_perm = [1, 0]
+ : (!transform.op<"tensor.pack">, !transform.op<"linalg.generic">)
+ -> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.any_op)
+ transform.yield
+ }
+}