[Mlir-commits] [mlir] 4c7225d - [mlir][Transform] Fix implementation of the generic apply that is based on applyToOne.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Jun 23 05:29:28 PDT 2022


Author: Nicolas Vasilache
Date: 2022-06-23T05:28:09-07:00
New Revision: 4c7225d19a9d1ff62c0ae39049ca3afe2a46c571

URL: https://github.com/llvm/llvm-project/commit/4c7225d19a9d1ff62c0ae39049ca3afe2a46c571
DIFF: https://github.com/llvm/llvm-project/commit/4c7225d19a9d1ff62c0ae39049ca3afe2a46c571.diff

LOG: [mlir][Transform] Fix implementation of the generic apply that is based on applyToOne.

Applying an N-result producing transformation to M payload ops must produce N
results, each containing M payload operations (one per target). `applyToOne`
returns the N result operations of one target at a time, so the results it
produces need to be transposed before being assigned to the op results.

This revision fixes the issue and adds more advanced tests that exercise the behavior.

Differential Revision: https://reviews.llvm.org/D128414

Added: 
    mlir/test/Dialect/Transform/selective-targeting.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 8f0dc16d35ab7..f3e42cefb2d45 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -348,7 +348,7 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
     does **not** fail when no ops were vectorized.
 
     Note that this transformation is invalidating the handles to any payload IR
-    operation that is contained inside the vectoriztaion target.
+    operation that is contained inside the vectorization target.
   }];
 
   let arguments = (ins PDL_Operation:$target,

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 390d6c6240657..e6fbfc88e31ee 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -794,6 +794,24 @@ applyTransformToEach(ArrayRef<Operation *> targets,
   }
   return DiagnosedSilenceableFailure::success();
 }
+
+/// Transposes M per-target rows of N ops each into N rows of M ops each, so
+/// that row `j` collects the `j`-th op produced for every target, in order.
+static inline SmallVector<SmallVector<Operation *, 1>>
+transposeResults(const SmallVector<SmallVector<Operation *>, 1> &m) {
+  if (m.empty())
+    return {};
+  const size_t numTargets = m.size();
+  const size_t numResults = m.front().size();
+  SmallVector<SmallVector<Operation *, 1>> transposed(
+      numResults, SmallVector<Operation *, 1>(numTargets, nullptr));
+  for (size_t target = 0; target < numTargets; ++target) {
+    assert(m[target].size() == numResults && "every row must hold N ops");
+    for (size_t result = 0; result < numResults; ++result)
+      transposed[result][target] = m[target][result];
+  }
+  return transposed;
+}
 } // namespace detail
 } // namespace transform
 } // namespace mlir
@@ -815,27 +833,51 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
       });
   if (!result.succeeded())
     return result;
-  for (const SmallVector<Operation *> &oneTargetResults : results) {
-    if (OpTy::template hasTrait<OpTrait::ZeroResults>())
-      continue;
-    if (OpTy::template hasTrait<OpTrait::OneResult>()) {
-      transformResults.set(
-          this->getOperation()->getResult(0).template cast<OpResult>(),
-          oneTargetResults);
-      continue;
-    }
-    if (this->getOperation()->getNumResults() != oneTargetResults.size()) {
-      Diagnostic diag(this->getOperation()->getLoc(),
-                      DiagnosticSeverity::Error);
-      diag << "unexpected number of results (got " << oneTargetResults.size()
-           << " expected " << this->getOperation()->getNumResults() << ")";
-      return DiagnosedSilenceableFailure::silencableFailure(std::move(diag));
-    }
-    for (const auto &it :
-         llvm::zip(this->getOperation()->getResults(), oneTargetResults)) {
-      transformResults.set(std::get<0>(it).template cast<OpResult>(),
-                           std::get<1>(it));
-    }
+  if (results.empty())
+    return DiagnosedSilenceableFailure::success();
+
+  // Ensure all applications return the same number of results.
+  // Variadic cases are much trickier to handle in a generic fashion.
+  int64_t nRes = results[0].size();
+  if (llvm::any_of(results, [&](const auto &r) {
+        return static_cast<int64_t>(r.size()) != nRes;
+      })) {
+    return static_cast<OpTy *>(this)->emitSilenceableError()
+           << "expected all applications of " << OpTy::getOperationName()
+           << " to produce " << nRes
+           << " results.\n If you need variadic results, consider using a "
+              "generic `apply` instead of the specialized `applyToOne`";
+  }
+  // Ensure the number of results agrees with what the transform op expects.
+  if (this->getOperation()->getNumResults() != nRes) {
+    InFlightDiagnostic diag = static_cast<OpTy *>(this)->emitError()
+                              << "unexpected number of results (got " << nRes
+                              << " expected "
+                              << this->getOperation()->getNumResults() << ")";
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  // If no results, bail early.
+  if (OpTy::template hasTrait<OpTrait::ZeroResults>())
+    return DiagnosedSilenceableFailure::success();
+
+  // Perform transposition of M applications producing N results each into N
+  // results for each of the M applications.
+  SmallVector<SmallVector<Operation *, 1>> transposedResults =
+      detail::transposeResults(results);
+  // Single result applies to M ops produces one single M-result.
+  if (OpTy::template hasTrait<OpTrait::OneResult>()) {
+    assert(transposedResults.size() == 1 && "Expected single result");
+    transformResults.set(
+        this->getOperation()->getResult(0).template cast<OpResult>(),
+        transposedResults[0]);
+    return DiagnosedSilenceableFailure::success();
+  }
+  // M ops, N results each.
+  for (const auto &it :
+       llvm::zip(this->getOperation()->getResults(), transposedResults)) {
+    transformResults.set(std::get<0>(it).template cast<OpResult>(),
+                         std::get<1>(it));
   }
   return DiagnosedSilenceableFailure::success();
 }

diff  --git a/mlir/test/Dialect/Transform/selective-targeting.mlir b/mlir/test/Dialect/Transform/selective-targeting.mlir
new file mode 100644
index 0000000000000..eee4c2bf54456
--- /dev/null
+++ b/mlir/test/Dialect/Transform/selective-targeting.mlir
@@ -0,0 +1,154 @@
+// RUN:  mlir-opt %s -test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @matmul_tensors_1(
+func.func @matmul_tensors_1(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+  %arg2: tensor<128x128xf32> {linalg.inplaceable = true})
+    -> tensor<128x128xf32> {
+  // Only test.attrA is set: this matmul is tiled by 4x4x4 but not vectorized.
+  // CHECK-COUNT-3: scf.for
+  // CHECK-COUNT-3: tensor.extract_slice
+  // CHECK: linalg.matmul
+  // CHECK-SAME: -> tensor<4x4xf32>
+  %0 = linalg.matmul { test.attrA }
+                      ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  func.return %0 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_tensors_2(
+func.func @matmul_tensors_2(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+  %arg2: tensor<128x128xf32> {linalg.inplaceable = true})
+    -> tensor<128x128xf32> {
+  // This operation is marked for tiling and vectorization.
+  // CHECK-COUNT-3: scf.for
+  // CHECK-COUNT-3: vector.transfer_read
+  // CHECK:       vector.contract
+  // CHECK-NOT:   linalg.matmul
+  // CHECK:       vector.transfer_write
+  %0 = linalg.matmul { test.attrA, test.attrC }
+                      ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  func.return %0 : tensor<128x128xf32>
+}
+
+func.func @matmul_tensors_3(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+  %arg2: tensor<128x128xf32> {linalg.inplaceable = true})
+    -> tensor<128x128xf32> {
+  // Only test.attrC is set: the enclosing function is vectorized, no tiling.
+  // CHECK-NOT: scf.for
+  // CHECK-COUNT-3: vector.transfer_read
+  // CHECK: vector.contract
+  // CHECK-SAME: into vector<128x128xf32>
+  // CHECK: vector.transfer_write
+  %0 = linalg.matmul { test.attrC }
+                      ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  func.return %0 : tensor<128x128xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  // Match every linalg.matmul that carries the test.attrA attribute.
+  pdl.pattern @pdl_target_attrA : benefit(1) {
+    %args = operands
+    %results = types
+    %attr = attribute
+    %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrA" = %attr}-> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  // Match every linalg.matmul that carries the test.attrC attribute.
+  pdl.pattern @pdl_target_attrC : benefit(1) {
+    %args = operands
+    %results = types
+    %attr = attribute
+    %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrC" = %attr}-> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    // First tile the test.attrA matmuls by 4x4x4...
+    %0 = pdl_match @pdl_target_attrA in %arg1
+    transform.structured.tile %0 {sizes = [4, 4, 4]}
+    // ...then vectorize the closest isolated-from-above parent of each
+    // test.attrC matmul; note that this match runs after the tiling above.
+    %1 = pdl_match @pdl_target_attrC in %arg1
+    %2 = transform.get_closest_isolated_parent %1
+    transform.structured.vectorize %2
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @vectorize_one
+func.func @vectorize_one(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+  %arg2: tensor<128x128xf32> {linalg.inplaceable = true})
+    -> tensor<128x128xf32> {
+  // Carries test.attrA, so @pdl_target matches it and this function is
+  // vectorized.
+  // CHECK: vector.contract
+  %0 = linalg.matmul {test.attrA}
+                     ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  func.return %0 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: @vectorize_none
+func.func @vectorize_none(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+  %arg2: tensor<128x128xf32> {linalg.inplaceable = true})
+    -> tensor<128x128xf32> {
+  // No test.attrA here, so this function must not be vectorized. The check is
+  // anchored on the SSA definition so that it cannot be satisfied by the
+  // "linalg.matmul" string printed in the PDL pattern later in this split.
+  // CHECK-NOT: vector.contract
+  // CHECK: %{{.*}} = linalg.matmul
+  %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  func.return %0 : tensor<128x128xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  // Match linalg.matmul ops carrying test.attrA (only the one in @vectorize_one).
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %attr = attribute
+    %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrA" = %attr}-> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    // Vectorize only the closest isolated-from-above parent of each match, so
+    // @vectorize_none is expected to stay untouched.
+    %0 = pdl_match @pdl_target in %arg1
+    %1 = get_closest_isolated_parent %0
+    transform.structured.vectorize %1
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @vectorize_all
+func.func @vectorize_all(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>,
+  %arg3: tensor<128x128xf32> {linalg.inplaceable = true})
+    -> tensor<128x128xf32> {
+  // Both matmuls are expected to be vectorized, whether or not they carry
+  // test.attrA: the transform below does no selective matching.
+  // CHECK: vector.contract
+  %0 = linalg.matmul {test.attrA}
+                     ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  // CHECK: vector.contract
+  %1 = linalg.matmul ins(%arg0, %0: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg3: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  return %1 : tensor<128x128xf32>
+}
+
+// No PDL matching: vectorize is applied directly to the sequence's root handle
+// (presumably the top-level payload op — confirm), covering every matmul above.
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  transform.structured.vectorize %arg0
+}

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 6a54b0ce6f59e..a487d1dbef19c 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
 
 // expected-remark @below {{applying transformation}}
 transform.test_transform_op
@@ -385,3 +385,54 @@ transform.sequence {
   // expected-error @below {{unexpected number of results (got 0 expected 3)}}
   transform.test_wrong_number_of_results %arg0
 }
+
+// -----
+
+// Two "op" payload ops for @some to match, so the transform op is applied twice.
+func.func @foo() {
+  "op" () : () -> ()
+  "op" () : () -> ()
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @some : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "op"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = pdl_match @some in %arg1
+    // The first application returns one op and the second returns none; the
+    // applyToOne-based `apply` must reject such inconsistent result counts.
+    // expected-error @below {{expected all applications of transform.test_wrong_number_of_multi_results to produce 1 results}}
+    transform.test_wrong_number_of_multi_results %0
+  }
+}
+
+// -----
+
+// Three "op" payload ops for @some to match, so the transform op is applied
+// three times.
+func.func @foo() {
+  "op" () : () -> ()
+  "op" () : () -> ()
+  "op" () : () -> ()
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @some : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "op"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = pdl_match @some in %arg1
+    // @some matches the 3 "op" ops and each application returns 2 new ops;
+    // after transposition %1#0 and %1#1 each hold 3 payload ops.
+    %1:2 = transform.test_correct_number_of_multi_results %0
+  }
+}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index c48a7936adb12..9242a51dbb840 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -232,6 +232,24 @@ mlir::test::TestWrongNumberOfResultsOp::applyToOne(
   return SmallVector<Operation *>{};
 }
 
+FailureOr<SmallVector<Operation *>>
+mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne(
+    Operation *op, transform::TransformState &state) {
+  // Deliberately return a different number of ops for different targets so
+  // that the generic `apply` reports a result-count mismatch: only the first
+  // op of its kind in a block yields a new op, later ones yield none.
+  // This is derived from the payload IR rather than a function-local static
+  // counter, which would persist across every application in the process
+  // (e.g. later --split-input-file chunks) and silently change behavior.
+  for (Operation *prev = op->getPrevNode(); prev; prev = prev->getPrevNode())
+    if (prev->getName() == op->getName())
+      return SmallVector<Operation *>{};
+  OperationState opState(op->getLoc(), "foo");
+  return SmallVector<Operation *>{OpBuilder(op).create(opState)};
+}
+
+FailureOr<SmallVector<Operation *>>
+mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne(
+    Operation *op, transform::TransformState &state) {
+  // Create two "foo" ops at the target's position and report both as this
+  // application's results.
+  OpBuilder builder(op);
+  OperationState opState(op->getLoc(), "foo");
+  Operation *first = builder.create(opState);
+  Operation *second = builder.create(opState);
+  return SmallVector<Operation *>{first, second};
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 1b8ddb9649c34..78eade06d0d65 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -144,4 +144,33 @@ def TestWrongNumberOfResultsOp
   }];
 }
 
+// Test op whose `applyToOne` does not return the same number of ops for every
+// target; used to check that the TransformEachOpTrait-based `apply` rejects
+// such inconsistent results.
+def TestWrongNumberOfMultiResultsOp
+  : Op<Transform_Dialect, "test_wrong_number_of_multi_results",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformEachOpTrait, TransformOpInterface]> {
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$result);
+  let assemblyFormat = "$target attr-dict";
+  let cppNamespace = "::mlir::test";
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
+        ::mlir::Operation *target, transform::TransformState &state);
+  }];
+}
+
+// Test op whose `applyToOne` returns two ops per target; the generic `apply`
+// transposes them so that each of the two results holds one op per target.
+def TestCorrectNumberOfMultiResultsOp
+  : Op<Transform_Dialect, "test_correct_number_of_multi_results",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformEachOpTrait, TransformOpInterface]> {
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$result1,
+                      PDL_Operation:$result2);
+  let assemblyFormat = "$target attr-dict";
+  let cppNamespace = "::mlir::test";
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
+        ::mlir::Operation *target, transform::TransformState &state);
+  }];
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD


        


More information about the Mlir-commits mailing list