[Mlir-commits] [mlir] 65eedce - [mlir] detensorize: don't accidentally convert function entry blocks
Alex Zinenko
llvmlistbot at llvm.org
Mon Apr 24 20:34:02 PDT 2023
Author: Alex Zinenko
Date: 2023-04-25T03:33:54Z
New Revision: 65eedcebdc03052959508911417bac548009652a
URL: https://github.com/llvm/llvm-project/commit/65eedcebdc03052959508911417bac548009652a
DIFF: https://github.com/llvm/llvm-project/commit/65eedcebdc03052959508911417bac548009652a.diff
LOG: [mlir] detensorize: don't accidentally convert function entry blocks
In the Linalg detensorize pass, dialect conversion could accidentally
trigger signature conversion of the function entry block after inlining
the body of a Linalg generic into it. Such a conversion is not desirable
because it would break the internal validity of the function op, which is
furthermore not supposed to be detensorized at its boundary. Mitigate
this by creating a dummy entry block (containing only a branch to the
original body) so Linalg operations are never inlined into it and the
conversion is never triggered.
Closes #62249.
Reviewed By: silvas
Differential Revision: https://reviews.llvm.org/D148983
Added:
mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
Modified:
mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 5289ed6c7f519..9012a634a2417 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -60,7 +60,7 @@ bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
});
}
-/// A conversion patttern for detensoring `linalg.generic` ops.
+/// A conversion pattern for detensoring `linalg.generic` ops.
class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -69,7 +69,7 @@ class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
ConversionPatternRewriter &rewriter) const override {
Block *originalBlock = op->getBlock();
- // Gather some information about the op before inling its region.
+ // Gather some information about the op before inlining its region.
Block *opEntryBlock = &*op.getRegion().begin();
YieldOp yieldOp = dyn_cast<YieldOp>(op.getRegion().back().getTerminator());
@@ -476,6 +476,18 @@ struct LinalgDetensorize
DenseSet<BlockArgument> blockArgsToDetensor;
FunctionOpInterface funcOp = getOperation();
+ // Make sure the entry block of the function doesn't contain any Linalg ops.
+ // Otherwise, it may lead to the signature of the block being changed by the
+ // dialect conversion below, which would make the function op invalid
+ // because its type shouldn't change.
+ IRRewriter rewriter(funcOp->getContext());
+ Block *entryBlock = &funcOp.getFunctionBody().front();
+ Block *postEntryBlock =
+ rewriter.splitBlock(entryBlock, entryBlock->begin());
+ rewriter.setInsertionPointToStart(entryBlock);
+ auto branch =
+ rewriter.create<cf::BranchOp>(rewriter.getUnknownLoc(), postEntryBlock);
+
if (aggressiveMode.getValue()) {
AggressiveDetensoringModel costModel;
costModel.compute(funcOp, typeConverter, opsToDetensor,
@@ -553,6 +565,11 @@ struct LinalgDetensorize
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(canonPatterns))))
signalPassFailure();
+
+ // Get rid of the dummy entry block we created in the beginning to work
+ // around dialect conversion signature rewriting.
+ rewriter.eraseOp(branch);
+ rewriter.mergeBlocks(postEntryBlock, entryBlock);
}
};
} // namespace
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
new file mode 100644
index 0000000000000..d1a89226fdb58
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline="builtin.module(func.func(linalg-detensorize))" | FileCheck %s
+
+#map = affine_map<() -> ()>
+func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
+ %0 = tensor.empty() : tensor<f32>
+ %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = []} ins(%arg0 : tensor<f32>) outs(%0 : tensor<f32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<f32>
+ cf.br ^bb1(%1 : tensor<f32>)
+^bb1(%2: tensor<f32>): // pred: ^bb0
+ return %2 : tensor<f32>
+}
+
+// CHECK-LABEL: @main
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
+// CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
+// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
+// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
+// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
+// CHECK: return %[[ELEMENTS]] : tensor<f32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
index 455fcfe7b498e..6d8d5fe71fca5 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
@@ -44,8 +44,8 @@ func.func @main() -> () attributes {} {
}
// CHECK-LABEL: func @main
-// CHECK-NEXT: arith.constant 0 : i32
-// CHECK-NEXT: arith.constant 10
+// CHECK-DAG: arith.constant 0 : i32
+// CHECK-DAG: arith.constant 10
// CHECK-NEXT: cf.br ^[[bb1:.*]](%{{.*}} : i32)
// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32)
// CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}}
More information about the Mlir-commits
mailing list