[Mlir-commits] [mlir] 83b582d - [mlir][Linalg] Properly propagate transform result in ScalarizeOp

Nicolas Vasilache llvmlistbot at llvm.org
Tue Dec 27 06:17:02 PST 2022


Author: Nicolas Vasilache
Date: 2022-12-27T06:16:55-08:00
New Revision: 83b582d51b742ad4a3e2b10e55058508b0e1ebc6

URL: https://github.com/llvm/llvm-project/commit/83b582d51b742ad4a3e2b10e55058508b0e1ebc6
DIFF: https://github.com/llvm/llvm-project/commit/83b582d51b742ad4a3e2b10e55058508b0e1ebc6.diff

LOG: [mlir][Linalg] Properly propagate transform result in ScalarizeOp

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
    mlir/test/Dialect/Linalg/transform-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 347c53085aa58..5660891d56529 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -67,14 +67,14 @@ DiagnosedSilenceableFailure
 transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
                                    SmallVectorImpl<Operation *> &results,
                                    transform::TransformState &state) {
-#define DOWNSCALE(trans) \
-    { \
-      FailureOr<LinalgOp> res = tryApply<trans>(target); \
-      if (succeeded(res)) { \
-        results.push_back(*res); \
-        return DiagnosedSilenceableFailure::success(); \
-      } \
-    }
+#define DOWNSCALE(trans)                                                       \
+  {                                                                            \
+    FailureOr<LinalgOp> res = tryApply<trans>(target);                         \
+    if (succeeded(res)) {                                                      \
+      results.push_back(*res);                                                 \
+      return DiagnosedSilenceableFailure::success();                           \
+    }                                                                          \
+  }
 
 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
@@ -986,6 +986,10 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
   if (failed(maybeTilingResult))
     return emitDefaultDefiniteFailure(target);
 
+  if (target->getNumResults())
+    rewriter.replaceOp(target, maybeTilingResult->replacements);
+  else
+    rewriter.eraseOp(target);
   results.append(maybeTilingResult->tiledOps);
   return DiagnosedSilenceableFailure::success();
 }

diff --git a/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
index 89c8d3265373b..fbf083c3d1ad8 100644
--- a/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
@@ -5,8 +5,16 @@ func.func @scalarize(%arg0: tensor<24x12xf32>,
                      %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
   // The op is first tiled by 10 in the first dimension, which creates a
   // dynamic size, and then scalarized, which brings the dimension to static 1.
-  // CHECK: linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<1x12
+  // CHECK: %[[RES_LOOP_1:.*]] = scf.for {{.*}} -> (tensor<24x25xf32>)
+  // CHECK:   %[[RES_LOOP_2:.*]] = scf.for {{.*}} -> (tensor<?x25xf32>)
+  // CHECK:     %[[MM:.*]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<1x12
+  // CHECK:     %[[INS_2:.*]] = tensor.insert_slice %[[MM]] into %{{.*}} [1, 25] [1, 1] : tensor<1x25xf32> into tensor<?x25xf32>
+  // CHECK:     scf.yield %[[INS_2]] : tensor<?x25xf32>
+  // CHECK:   %[[INS_1:.*]] = tensor.insert_slice %[[RES_LOOP_2]] into %{{.*}}, 25] [1, 1] : tensor<?x25xf32> into tensor<24x25xf32>
+  // CHECK:   scf.yield %[[INS_1]] : tensor<24x25xf32>
   %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+
+  // CHECK: return %[[RES_LOOP_1]] : tensor<24x25xf32>
   func.return %0 : tensor<24x25xf32>
 }
 

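The transform IR that drives this test is not part of the diff above. As a rough sketch only (handle names and the exact tile/scalarize syntax here are illustrative, not copied from the test file), the sequence exercising this path looks roughly like:

  transform.sequence failures(propagate) {
  ^bb0(%arg1: !pdl.operation):
    // Match the payload matmul.
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
    // Tile the first dimension by 10; 24 is not a multiple of 10, so the
    // resulting tiles have a dynamic size.
    %tiled, %loop = transform.structured.tile %0 [10, 0, 0]
    // Scalarize the remaining dynamic dimension down to a static 1. With this
    // change the original payload op is replaced (or erased when it has no
    // results), so its uses flow through the new loop nest.
    %scalarized = transform.structured.scalarize %tiled
  }

With the fix, ScalarizeOp replaces the payload op with the tiling replacements before recording the tiled ops, which is why the function now returns the value yielded by the outer scf.for, as the updated CHECK lines verify.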
diff --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-ops.mlir
index 898cce730e536..64cf3fbd04f90 100644
--- a/mlir/test/Dialect/Linalg/transform-ops.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops.mlir
@@ -8,7 +8,7 @@ transform.sequence failures(propagate) {
 
 //===----------------------------------------------------------------------===//
 // Check that operations are registered correctly through the extension
-// mechanism. Their syntax is generated and requries no additional testing since
+// mechanism. Their syntax is generated and requires no additional testing since
 // we test the generator.
 //===----------------------------------------------------------------------===//
 

