[Mlir-commits] [mlir] 65eedce - [mlir] detensorize: don't accidentally convert function entry blocks

Alex Zinenko llvmlistbot at llvm.org
Mon Apr 24 20:34:02 PDT 2023


Author: Alex Zinenko
Date: 2023-04-25T03:33:54Z
New Revision: 65eedcebdc03052959508911417bac548009652a

URL: https://github.com/llvm/llvm-project/commit/65eedcebdc03052959508911417bac548009652a
DIFF: https://github.com/llvm/llvm-project/commit/65eedcebdc03052959508911417bac548009652a.diff

LOG: [mlir] detensorize: don't accidentally convert function entry blocks

In the Linalg detensorize pass, dialect conversion could accidentally
trigger signature conversion of the function entry block after inlining
the body of a Linalg generic into it. Such a conversion is not desirable
because it would break the internal validity of the function op, which
furthermore is not supposed to be detensorized at its boundary. Mitigate
this by creating a dummy entry block that contains only a branch to the
original entry block, so Linalg operations are never inlined into the
entry block and the conversion is never triggered. The dummy block is
merged back into the original entry block at the end of the pass.

Closes #62249.

Reviewed By: silvas

Differential Revision: https://reviews.llvm.org/D148983

Added: 
    mlir/test/Dialect/Linalg/detensorize_entry_block.mlir

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
    mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 5289ed6c7f519..9012a634a2417 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -60,7 +60,7 @@ bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
          });
 }
 
-/// A conversion patttern for detensoring `linalg.generic` ops.
+/// A conversion pattern for detensoring `linalg.generic` ops.
 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -69,7 +69,7 @@ class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Block *originalBlock = op->getBlock();
 
-    // Gather some information about the op before inling its region.
+    // Gather some information about the op before inlining its region.
     Block *opEntryBlock = &*op.getRegion().begin();
     YieldOp yieldOp = dyn_cast<YieldOp>(op.getRegion().back().getTerminator());
 
@@ -476,6 +476,18 @@ struct LinalgDetensorize
     DenseSet<BlockArgument> blockArgsToDetensor;
     FunctionOpInterface funcOp = getOperation();
 
+    // Make sure the entry block of the function doesn't contain any Linalg ops.
+    // Otherwise, it may lead to the signature of the block being changed by the
+    // dialect conversion below, which would make the function op invalid
+    // because its type shouldn't change.
+    IRRewriter rewriter(funcOp->getContext());
+    Block *entryBlock = &funcOp.getFunctionBody().front();
+    Block *postEntryBlock =
+        rewriter.splitBlock(entryBlock, entryBlock->begin());
+    rewriter.setInsertionPointToStart(entryBlock);
+    auto branch =
+        rewriter.create<cf::BranchOp>(rewriter.getUnknownLoc(), postEntryBlock);
+
     if (aggressiveMode.getValue()) {
       AggressiveDetensoringModel costModel;
       costModel.compute(funcOp, typeConverter, opsToDetensor,
@@ -553,6 +565,11 @@ struct LinalgDetensorize
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(canonPatterns))))
       signalPassFailure();
+
+    // Get rid of the dummy entry block we created in the beginning to work
+    // around dialect conversion signature rewriting.
+    rewriter.eraseOp(branch);
+    rewriter.mergeBlocks(postEntryBlock, entryBlock);
   }
 };
 } // namespace

diff  --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
new file mode 100644
index 0000000000000..d1a89226fdb58
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline="builtin.module(func.func(linalg-detensorize))" | FileCheck %s
+
+#map = affine_map<() -> ()>
+func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = tensor.empty() : tensor<f32>
+  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = []} ins(%arg0 : tensor<f32>) outs(%0 : tensor<f32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<f32>
+  cf.br ^bb1(%1 : tensor<f32>)
+^bb1(%2: tensor<f32>):  // pred: ^bb0
+  return %2 : tensor<f32>
+}
+
+// CHECK-LABEL: @main
+// CHECK-SAME:       (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
+// CHECK:   %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
+// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
+// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
+// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
+// CHECK:   return %[[ELEMENTS]] : tensor<f32>

diff  --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
index 455fcfe7b498e..6d8d5fe71fca5 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
@@ -44,8 +44,8 @@ func.func @main() -> () attributes {} {
 }
 
 // CHECK-LABEL: func @main
-// CHECK-NEXT:    arith.constant 0 : i32
-// CHECK-NEXT:    arith.constant 10
+// CHECK-DAG:    arith.constant 0 : i32
+// CHECK-DAG:    arith.constant 10
 // CHECK-NEXT:    cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // CHECK-NEXT:  ^[[bb1]](%{{.*}}: i32)
 // CHECK-NEXT:    %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}}


        


More information about the Mlir-commits mailing list