[Mlir-commits] [mlir] 83b582d - [mlir][Linalg] Properly propagate transform result in ScalarizeOp
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Dec 27 06:17:02 PST 2022
Author: Nicolas Vasilache
Date: 2022-12-27T06:16:55-08:00
New Revision: 83b582d51b742ad4a3e2b10e55058508b0e1ebc6
URL: https://github.com/llvm/llvm-project/commit/83b582d51b742ad4a3e2b10e55058508b0e1ebc6
DIFF: https://github.com/llvm/llvm-project/commit/83b582d51b742ad4a3e2b10e55058508b0e1ebc6.diff
LOG: [mlir][Linalg] Properly propagate transform result in ScalarizeOp
Added:
Modified:
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
mlir/test/Dialect/Linalg/transform-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 347c53085aa58..5660891d56529 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -67,14 +67,14 @@ DiagnosedSilenceableFailure
transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
-#define DOWNSCALE(trans) \
- { \
- FailureOr<LinalgOp> res = tryApply<trans>(target); \
- if (succeeded(res)) { \
- results.push_back(*res); \
- return DiagnosedSilenceableFailure::success(); \
- } \
- }
+#define DOWNSCALE(trans) \
+ { \
+ FailureOr<LinalgOp> res = tryApply<trans>(target); \
+ if (succeeded(res)) { \
+ results.push_back(*res); \
+ return DiagnosedSilenceableFailure::success(); \
+ } \
+ }
#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
@@ -986,6 +986,10 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
if (failed(maybeTilingResult))
return emitDefaultDefiniteFailure(target);
+ if (target->getNumResults())
+ rewriter.replaceOp(target, maybeTilingResult->replacements);
+ else
+ rewriter.eraseOp(target);
results.append(maybeTilingResult->tiledOps);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
index 89c8d3265373b..fbf083c3d1ad8 100644
--- a/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
@@ -5,8 +5,16 @@ func.func @scalarize(%arg0: tensor<24x12xf32>,
%arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
// The op is first tiled by 10 in the first dimension, which creates a
// dynamic size, and then scalarized, which brings the dimension to static 1.
- // CHECK: linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<1x12
+ // CHECK: %[[RES_LOOP_1:.*]] = scf.for {{.*}} -> (tensor<24x25xf32>)
+ // CHECK: %[[RES_LOOP_2:.*]] = scf.for {{.*}} -> (tensor<?x25xf32>)
+ // CHECK: %[[MM:.*]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<1x12
+ // CHECK: %[[INS_2:.*]] = tensor.insert_slice %[[MM]] into %{{.*}} [1, 25] [1, 1] : tensor<1x25xf32> into tensor<?x25xf32>
+ // CHECK: scf.yield %[[INS_2]] : tensor<?x25xf32>
+ // CHECK: %[[INS_1:.*]] = tensor.insert_slice %[[RES_LOOP_2]] into %{{.*}}, 25] [1, 1] : tensor<?x25xf32> into tensor<24x25xf32>
+ // CHECK: scf.yield %[[INS_1]] : tensor<24x25xf32>
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+
+ // CHECK: return %[[RES_LOOP_1]] : tensor<24x25xf32>
func.return %0 : tensor<24x25xf32>
}
diff --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-ops.mlir
index 898cce730e536..64cf3fbd04f90 100644
--- a/mlir/test/Dialect/Linalg/transform-ops.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops.mlir
@@ -8,7 +8,7 @@ transform.sequence failures(propagate) {
//===----------------------------------------------------------------------===//
// Check that operations are registered correctly through the extension
-// mechanism. Their syntax is generated and requries no additional testing since
+// mechanism. Their syntax is generated and requires no additional testing since
// we test the generator.
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list