[Mlir-commits] [mlir] [mlir][linalg] Expose data layout propagation patterns via transform op (PR #184151)
Adam Siemieniuk
llvmlistbot at llvm.org
Mon Mar 2 07:12:03 PST 2026
https://github.com/adam-smnk created https://github.com/llvm/llvm-project/pull/184151
Enables access to the data layout propagation patterns from transform schedules by providing a transform operation that wraps them.
>From 8963484ccdf6f8ee207ac3d554f2966e9a356532 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 2 Mar 2026 16:03:41 +0100
Subject: [PATCH] [mlir][linalg] Expose data layout propagation patterns via
transform op
Enables access to the data layout propagation patterns from transform
schedules by providing a transform operation that wraps them.
---
.../Linalg/TransformOps/LinalgTransformOps.td | 12 ++++
.../TransformOps/LinalgTransformOps.cpp | 9 +++
.../transform-op-data-layout-propagation.mlir | 57 +++++++++++++++++++
3 files changed, 78 insertions(+)
create mode 100644 mlir/test/Dialect/Linalg/transform-op-data-layout-propagation.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 70d424bae9285..87c876f05c3bd 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -143,6 +143,18 @@ def ApplyFoldPackUnpackIntoEmptyPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyDataLayoutPropagationPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.linalg.data_layout_propagation",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collection of patterns to bubble up or push down data layout ops (e.g.
+    `linalg.pack` / `linalg.unpack`) across other operations. Setting
+    `poison_padding` also allows propagation that requires poison padding.
+  }];
+  let arguments = (ins DefaultValuedAttr<BoolAttr, "false">:$poison_padding);
+  let assemblyFormat = "attr-dict";
+}
+
//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e945a15476b3a..f3b7fccee0892 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -272,6 +272,15 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
linalg::populateFoldPackUnpackIntoTensorEmptyPatterns(patterns);
}
+void transform::ApplyDataLayoutPropagationPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  // The op exposes no control hook, so propagation is allowed through every
+  // operand; only the `poison_padding` attribute is forwarded to the patterns.
+  linalg::ControlPropagationFn controlFn = [](OpOperand *) { return true; };
+  linalg::populateDataLayoutPropagationPatterns(patterns, controlFn,
+                                                getPoisonPadding());
+}
+
//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/transform-op-data-layout-propagation.mlir
new file mode 100644
index 0000000000000..301537507c972
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-data-layout-propagation.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
+
+// Case 1: poison_padding = false. The unpack below writes into a dynamically
+// shaped destination (tensor<?x64xf32>); the checks below expect it to stay in
+// front of the linalg.generic that consumes it.
+
+module @transforms attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.data_layout_propagation {poison_padding = false}
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+func.func @no_propagation_without_poison(%arg0: tensor<8x8x4x8xf32>, %dest: tensor<?x64xf32>, %arg1: tensor<?x64xbf16>) -> tensor<?x64xbf16> {
+  %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %dest : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x64xf32>) outs(%arg1 : tensor<?x64xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<?x64xbf16>
+  return %0 : tensor<?x64xbf16>
+}
+// CHECK-LABEL: func.func @no_propagation_without_poison
+// CHECK: %[[UNPACK:.+]] = linalg.unpack
+// CHECK: linalg.generic{{.*}}ins(%[[UNPACK]]
+
+// -----
+
+// Case 2: poison_padding = true, same payload as case 1. The checks below
+// expect the unpack to be pushed below the linalg.generic and to consume the
+// generic's result.
+
+module @transforms attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.data_layout_propagation {poison_padding = true}
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+func.func @propagation_with_poison(%arg0: tensor<8x8x4x8xf32>, %dest: tensor<?x64xf32>, %arg1: tensor<?x64xbf16>) -> tensor<?x64xbf16> {
+  %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %dest : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpack : tensor<?x64xf32>) outs(%arg1 : tensor<?x64xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<?x64xbf16>
+  return %0 : tensor<?x64xbf16>
+}
+// CHECK-LABEL: func.func @propagation_with_poison
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: linalg.unpack %[[GENERIC]]
More information about the Mlir-commits
mailing list