[libcxx-commits] [mlir] [clang] [llvm] [libc] [clang-tools-extra] [compiler-rt] [flang] [lldb] [lld] [openmp] [libcxx] [mlir][transform] Add elementwise criteria to `match.structured.body` (PR #79626)

via libcxx-commits libcxx-commits at lists.llvm.org
Mon Jan 29 09:32:37 PST 2024


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/79626

>From ab475c9ffb7c3562bad4772389e97b82e9f110c0 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 26 Jan 2024 11:55:06 -0600
Subject: [PATCH 1/3] Add elementwise criteria to match.structured.body

---
 .../Linalg/TransformOps/LinalgMatchOps.td     |  4 +++
 .../Linalg/TransformOps/LinalgMatchOps.cpp    |  9 ++++-
 .../Dialect/Linalg/match-ops-interpreter.mlir | 34 +++++++++++++++++++
 .../Dialect/Linalg/match-ops-invalid.mlir     |  2 +-
 4 files changed, 47 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
index 162dd05f93030f..dfeb8ae5d5ddbc 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
@@ -106,6 +106,9 @@ def MatchStructuredBodyOp : Op<Transform_Dialect, "match.structured.body", [
       * `passthrough`: the body of the structured payload op only forwards
         inputs to the outputs (copy or broadcast).
 
+      * `elementwise`: the body of the structured payload op represents an
+        elementwise operation.
+
       * `contraction`: the body of the structured payload op is a contraction
         of the form `<red>(<elem>(bbarg0, bbarg1), bbarg2)` where `<elem>` and
         `<red>` are binary operations whose names are specified in the attribute
@@ -123,6 +126,7 @@ def MatchStructuredBodyOp : Op<Transform_Dialect, "match.structured.body", [
   let arguments = (ins TransformHandleTypeInterface:$operand_handle,
                        OptionalAttr<I64Attr>:$reduction_position,
                        UnitAttr:$passthrough,
+                       UnitAttr:$elementwise,
                        OptionalAttr<StrArrayAttr>:$contraction);
   let assemblyFormat = "$operand_handle attr-dict `:` type($operand_handle)";
   let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 115da4b90e063a..fb18886c16b16d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
@@ -187,6 +188,11 @@ DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
     }
     return DiagnosedSilenceableFailure::success();
   }
+  if (getElementwise()) {
+    if (!isElementwise(linalgOp))
+      return emitSilenceableError() << "not elementwise";
+    return DiagnosedSilenceableFailure::success();
+  }
   if (std::optional<ArrayAttr> contractionOps = getContraction()) {
     Block &body = linalgOp->getRegion(0).front();
     std::string message;
@@ -209,13 +215,14 @@ DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
 
 LogicalResult transform::MatchStructuredBodyOp::verify() {
   int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
-                       getContraction().has_value();
+                       getElementwise() + getContraction().has_value();
 
   if (numOptions > 1) {
     std::string attributeNames;
     llvm::raw_string_ostream os(attributeNames);
     llvm::interleaveComma(ArrayRef<StringAttr>{getReductionPositionAttrName(),
                                                getPassthroughAttrName(),
+                                               getElementwiseAttrName(),
                                                getContractionAttrName()},
                           os);
     return emitOpError() << "only one of {" << os.str() << "} is allowed";
diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index a7353a4c38881e..0efe70a7b9ae1e 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -180,6 +180,40 @@ module attributes { transform.with_named_sequence } {
 
 // -----
 
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @print_elementwise(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "elementwise" : !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @match_structured_body_elementwise(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
+    ^bb0(%arg1: !transform.any_op):
+      transform.match.structured.body %arg1 { elementwise } : !transform.any_op
+      transform.match.structured.yield %arg1 : !transform.any_op
+    }
+    transform.yield %0 : !transform.any_op
+  }
+
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
+    transform.foreach_match in %arg0
+        @match_structured_body_elementwise -> @print_elementwise
+        : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+
+  func.func @payload(%in1: tensor<2xf32>, %in2: tensor<2xf32>, %out: tensor<2xf32>) -> tensor<2xf32> attributes { transform.target_tag = "start_here" } {
+    %cst0 = arith.constant 0.0 : f32
+    // expected-remark @below {{elementwise}}
+    %fill = linalg.fill ins(%cst0: f32) outs(%out: tensor<2xf32>) -> tensor<2xf32>
+    // expected-remark @below {{elementwise}}
+    %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<2xf32>, tensor<2xf32>) outs(%fill: tensor<2xf32>)
+    return %add : tensor<2xf32>
+  }
+}
+
+// -----
+
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @print_reduction(%arg0: !transform.any_op {transform.readonly}) {
     transform.debug.emit_remark_at %arg0, "reduction" : !transform.any_op
diff --git a/mlir/test/Dialect/Linalg/match-ops-invalid.mlir b/mlir/test/Dialect/Linalg/match-ops-invalid.mlir
index ec99e205090c4c..9ff430a3503606 100644
--- a/mlir/test/Dialect/Linalg/match-ops-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-invalid.mlir
@@ -64,7 +64,7 @@ transform.sequence failures(suppress) {
 ^bb0(%arg0: !transform.any_op):
   transform.match.structured %arg0 : !transform.any_op {
   ^bb1(%arg1: !transform.any_op):
-    // expected-error @below {{only one of {"reduction_position", "passthrough", "contraction"} is allowed}}
+    // expected-error @below {{only one of {"reduction_position", "passthrough", "elementwise", "contraction"} is allowed}}
     transform.match.structured.body %arg1 { passthrough, reduction_position = 0 } : !transform.any_op
     transform.match.structured.yield
   }

>From a1cb4dfafcc64c51409d67e6396b93320508af99 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 26 Jan 2024 14:48:19 -0600
Subject: [PATCH 2/3] Add broadcast elementwise test

---
 .../Dialect/Linalg/match-ops-interpreter.mlir     | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index 0efe70a7b9ae1e..6e05c6e17de18b 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -182,7 +182,7 @@ module attributes { transform.with_named_sequence } {
 
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @print_elementwise(%arg0: !transform.any_op {transform.readonly}) {
-    transform.test_print_remark_at_operand %arg0, "elementwise" : !transform.any_op
+    transform.debug.emit_remark_at %arg0, "elementwise" : !transform.any_op
     transform.yield
   }
 
@@ -202,13 +202,22 @@ module attributes { transform.with_named_sequence } {
     transform.yield
   }
 
-  func.func @payload(%in1: tensor<2xf32>, %in2: tensor<2xf32>, %out: tensor<2xf32>) -> tensor<2xf32> attributes { transform.target_tag = "start_here" } {
+  func.func @payload(%in1: tensor<2xf32>, %in2: tensor<2xf32>, %in3: tensor<2x3xf32>, %out: tensor<2xf32>, %out2: tensor<2x3xf32>) -> (tensor<2xf32>, tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } {
     %cst0 = arith.constant 0.0 : f32
     // expected-remark @below {{elementwise}}
     %fill = linalg.fill ins(%cst0: f32) outs(%out: tensor<2xf32>) -> tensor<2xf32>
     // expected-remark @below {{elementwise}}
     %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<2xf32>, tensor<2xf32>) outs(%fill: tensor<2xf32>)
-    return %add : tensor<2xf32>
+    // expected-remark @below {{elementwise}}
+    %add_bcast = linalg.generic
+      {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"]}
+      ins(%in1, %in3: tensor<2xf32>, tensor<2x3xf32>) outs(%out2: tensor<2x3xf32>) {
+        ^bb0(%arg0: f32, %arg1: f32, %arg3: f32):
+          %0 = arith.addf %arg0, %arg1 : f32
+          linalg.yield %0 : f32
+      } -> tensor<2x3xf32>
+    return %add, %add_bcast : tensor<2xf32>, tensor<2x3xf32>
   }
 }
 

>From bd1a89f888060d94c3326e2218bbfb2d9bda24c1 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 26 Jan 2024 15:02:29 -0600
Subject: [PATCH 3/3] Add non-elementwise test

---
 .../Dialect/Linalg/match-ops-interpreter.mlir  | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index 6e05c6e17de18b..24c7bdd9e1050e 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -202,12 +202,26 @@ module attributes { transform.with_named_sequence } {
     transform.yield
   }
 
-  func.func @payload(%in1: tensor<2xf32>, %in2: tensor<2xf32>, %in3: tensor<2x3xf32>, %out: tensor<2xf32>, %out2: tensor<2x3xf32>) -> (tensor<2xf32>, tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } {
+  func.func @payload(%in1: tensor<2xf32>, %in2: tensor<2xf32>, %in3: tensor<2x3xf32>, %out: tensor<2xf32>, %out2: tensor<2x3xf32>) -> (tensor<2xf32>, tensor<2x3xf32>, tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } {
     %cst0 = arith.constant 0.0 : f32
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
     // expected-remark @below {{elementwise}}
     %fill = linalg.fill ins(%cst0: f32) outs(%out: tensor<2xf32>) -> tensor<2xf32>
     // expected-remark @below {{elementwise}}
     %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<2xf32>, tensor<2xf32>) outs(%fill: tensor<2xf32>)
+    %non_elementwise = linalg.generic
+      {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"]}
+      ins(%in1, %in3: tensor<2xf32>, tensor<2x3xf32>) outs(%out2: tensor<2x3xf32>) {
+        ^bb0(%arg0: f32, %arg1: f32, %arg3: f32):
+          %0 = arith.addf %arg0, %arg1 : f32
+          %1 = tensor.dim %add, %c0 : tensor<2xf32>
+          %2 = arith.subi %1, %c1 : index
+          %3 = tensor.extract %add[%2] : tensor<2xf32>
+          %4 = arith.mulf %0, %3 : f32
+          linalg.yield %4 : f32
+      } -> tensor<2x3xf32>
     // expected-remark @below {{elementwise}}
     %add_bcast = linalg.generic
       {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
@@ -217,7 +231,7 @@ module attributes { transform.with_named_sequence } {
           %0 = arith.addf %arg0, %arg1 : f32
           linalg.yield %0 : f32
       } -> tensor<2x3xf32>
-    return %add, %add_bcast : tensor<2xf32>, tensor<2x3xf32>
+    return %add, %add_bcast, %non_elementwise : tensor<2xf32>, tensor<2x3xf32>, tensor<2x3xf32>
   }
 }
 



More information about the libcxx-commits mailing list