[Mlir-commits] [mlir] [mlir][linalg] Expose data layout propagation patterns via transform op (PR #184151)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 2 07:12:40 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Adam Siemieniuk (adam-smnk)

<details>
<summary>Changes</summary>

Enables access to data layout propagation patterns from transform schedules by providing a transform op that wraps them.

---
Full diff: https://github.com/llvm/llvm-project/pull/184151.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+12) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+9) 
- (added) mlir/test/Dialect/Linalg/transform-op-data-layout-propagation.mlir (+57) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 70d424bae9285..87c876f05c3bd 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -143,6 +143,18 @@ def ApplyFoldPackUnpackIntoEmptyPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyDataLayoutPropagationPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.linalg.data_layout_propagation",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collection of patterns to bubble up or down data layout ops across other
+    operations. If `poison_padding` is set, the patterns are also allowed to
+    propagate when padding is required, using poison as the padding value.
+  }];
+  let arguments = (ins DefaultValuedAttr<BoolAttr, "false">:$poison_padding);
+  let assemblyFormat = "attr-dict";
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e945a15476b3a..f3b7fccee0892 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -272,6 +272,15 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
   linalg::populateFoldPackUnpackIntoTensorEmptyPatterns(patterns);
 }
 
+void transform::ApplyDataLayoutPropagationPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  // Every operand is eligible; `poison_padding` is the only exposed knob.
+  auto propagateEverywhere = [](OpOperand *) { return true; };
+  linalg::populateDataLayoutPropagationPatterns(patterns, propagateEverywhere,
+                                                getPoisonPadding());
+}
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/transform-op-data-layout-propagation.mlir
new file mode 100644
index 0000000000000..301537507c972
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-data-layout-propagation.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
+
+module @schedule attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %targets = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.foreach %targets : !transform.any_op {
+    ^bb0(%target: !transform.any_op):
+      transform.apply_patterns to %target {
+        transform.apply_patterns.linalg.data_layout_propagation {poison_padding = false}
+      } : !transform.any_op
+      transform.yield
+    }
+    transform.yield
+  }
+}
+
+func.func @no_propagation_without_poison(%packed: tensor<8x8x4x8xf32>, %unpack_dest: tensor<?x64xf32>, %init: tensor<?x64xbf16>) -> tensor<?x64xbf16> {
+  %unpacked = linalg.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %unpack_dest : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
+  %res = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpacked : tensor<?x64xf32>) outs(%init : tensor<?x64xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %trunc = arith.truncf %in : f32 to bf16
+    linalg.yield %trunc : bf16
+  } -> tensor<?x64xbf16>
+  return %res : tensor<?x64xbf16>
+}
+// CHECK-LABEL:  func.func @no_propagation_without_poison
+// CHECK:          %[[UNPACKED:.+]] = linalg.unpack
+// CHECK:          linalg.generic{{.*}}ins(%[[UNPACKED]]
+
+// -----
+
+module @schedule attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %targets = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.foreach %targets : !transform.any_op {
+    ^bb0(%target: !transform.any_op):
+      transform.apply_patterns to %target {
+        transform.apply_patterns.linalg.data_layout_propagation {poison_padding = true}
+      } : !transform.any_op
+      transform.yield
+    }
+    transform.yield
+  }
+}
+
+func.func @propagation_with_poison(%packed: tensor<8x8x4x8xf32>, %unpack_dest: tensor<?x64xf32>, %init: tensor<?x64xbf16>) -> tensor<?x64xbf16> {
+  %unpacked = linalg.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %unpack_dest : tensor<8x8x4x8xf32> -> tensor<?x64xf32>
+  %res = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%unpacked : tensor<?x64xf32>) outs(%init : tensor<?x64xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %trunc = arith.truncf %in : f32 to bf16
+    linalg.yield %trunc : bf16
+  } -> tensor<?x64xbf16>
+  return %res : tensor<?x64xbf16>
+}
+// CHECK-LABEL:  func.func @propagation_with_poison
+// CHECK:          %[[RESULT:.+]] = linalg.generic
+// CHECK:          linalg.unpack %[[RESULT]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/184151


More information about the Mlir-commits mailing list