[Mlir-commits] [mlir] a0ef12c - [mlir][LLVM] `LLVMTypeConverter`: Tighten materialization checks (#116532)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Nov 23 19:20:13 PST 2024
Author: Matthias Springer
Date: 2024-11-24T12:20:09+09:00
New Revision: a0ef12c64284abf59bc092b2535cce1247d5f9a4
URL: https://github.com/llvm/llvm-project/commit/a0ef12c64284abf59bc092b2535cce1247d5f9a4
DIFF: https://github.com/llvm/llvm-project/commit/a0ef12c64284abf59bc092b2535cce1247d5f9a4.diff
LOG: [mlir][LLVM] `LLVMTypeConverter`: Tighten materialization checks (#116532)
This commit adds extra checks to the MemRef argument materializations in
the LLVM type converter. These materializations construct a
`MemRefType`/`UnrankedMemRefType` from the unpacked elements of a MemRef
descriptor or from a bare pointer.
The extra checks ensure that the inputs to the materialization functions
are correct. A user may have added extra type conversion rules that
convert MemRef types in a different way; the extra checks ensure that a
MemRef descriptor is constructed only if the inputs have the types we
expect.
This commit also drops a check around bare pointer materializations:
```
// This is a bare pointer. We allow bare pointers only for function entry
// blocks.
```
This check should not be part of the materialization function. Whether a
MemRef block argument is converted into a MemRef descriptor or a bare
pointer is decided by the lowering pattern. By the time materialization
functions are executed, that decision has already been made, so they
should materialize regardless of the input format.
Added:
mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
Modified:
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
mlir/test/lib/Dialect/LLVM/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ce91424e7a577e..59b0f5c9b09bcd 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,6 +153,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
type.isVarArg());
});
+ // Helper function that checks if the given value range is a bare pointer.
+ auto isBarePointer = [](ValueRange values) {
+ return values.size() == 1 &&
+ isa<LLVM::LLVMPointerType>(values.front().getType());
+ };
+
// Argument materializations convert from the new block argument types
// (multiple SSA values that make up a memref descriptor) back to the
// original block argument type. The dialect conversion framework will then
@@ -161,11 +167,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addArgumentMaterialization([&](OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs, Location loc) {
- if (inputs.size() == 1) {
- // Bare pointers are not supported for unranked memrefs because a
- // memref descriptor cannot be built just from a bare pointer.
+ // Note: Bare pointers are not supported for unranked memrefs because a
+ // memref descriptor cannot be built just from a bare pointer.
+ if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
return Value();
- }
Value desc =
UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
// An argument materialization must return a value of type
@@ -177,20 +182,17 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs, Location loc) {
Value desc;
- if (inputs.size() == 1) {
- // This is a bare pointer. We allow bare pointers only for function entry
- // blocks.
- BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
- if (!barePtr)
- return Value();
- Block *block = barePtr.getOwner();
- if (!block->isEntryBlock() ||
- !isa<FunctionOpInterface>(block->getParentOp()))
- return Value();
+ if (isBarePointer(inputs)) {
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
inputs[0]);
- } else {
+ } else if (TypeRange(inputs) ==
+ getMemRefDescriptorFields(resultType,
+ /*unpackAggregates=*/true)) {
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+ } else {
+ // The inputs are neither a bare pointer nor an unpacked memref
+ // descriptor. This materialization function cannot be used.
+ return Value();
}
// An argument materialization must return a value of type `resultType`,
// so insert a cast from the memref descriptor type (!llvm.struct) to the
diff --git a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
new file mode 100644
index 00000000000000..0288aa11313c72
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file | FileCheck %s
+
+// Test the argument materializer for ranked MemRef types.
+
+// CHECK-LABEL: func @construct_ranked_memref_descriptor(
+// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-COUNT-7: llvm.insertvalue
+// CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<5x4xf32>
+func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) {
+ %0 = "test.direct_replacement"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64) -> (memref<5x4xf32>)
+ "test.legal_op"(%0) : (memref<5x4xf32>) -> ()
+ return
+}
+
+// -----
+
+// The argument materializer for ranked MemRef types is called with incorrect
+// input types. Make sure that the materializer is skipped and we do not
+// generate invalid IR.
+
+// CHECK-LABEL: func @invalid_ranked_memref_descriptor(
+// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<5x4xf32>
+// CHECK: "test.legal_op"(%[[cast]])
+func.func @invalid_ranked_memref_descriptor(%arg0: i1) {
+ %0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<5x4xf32>)
+ "test.legal_op"(%0) : (memref<5x4xf32>) -> ()
+ return
+}
+
+// -----
+
+// Test the argument materializer for unranked MemRef types.
+
+// CHECK-LABEL: func @construct_unranked_memref_descriptor(
+// CHECK: llvm.mlir.undef : !llvm.struct<(i64, ptr)>
+// CHECK-COUNT-2: llvm.insertvalue
+// CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(i64, ptr)> to memref<*xf32>
+func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) {
+ %0 = "test.direct_replacement"(%arg0, %arg1) : (i64, !llvm.ptr) -> (memref<*xf32>)
+ "test.legal_op"(%0) : (memref<*xf32>) -> ()
+ return
+}
+
+// -----
+
+// The argument materializer for unranked MemRef types is called with incorrect
+// input types. Make sure that the materializer is skipped and we do not
+// generate invalid IR.
+
+// CHECK-LABEL: func @invalid_unranked_memref_descriptor(
+// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<*xf32>
+// CHECK: "test.legal_op"(%[[cast]])
+func.func @invalid_unranked_memref_descriptor(%arg0: i1) {
+ %0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<*xf32>)
+ "test.legal_op"(%0) : (memref<*xf32>) -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/LLVM/CMakeLists.txt b/mlir/test/lib/Dialect/LLVM/CMakeLists.txt
index 734757ce79da37..6a2f0ba2756d43 100644
--- a/mlir/test/lib/Dialect/LLVM/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/LLVM/CMakeLists.txt
@@ -1,6 +1,7 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRLLVMTestPasses
TestLowerToLLVM.cpp
+ TestPatterns.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
new file mode 100644
index 00000000000000..ab02866970b1d5
--- /dev/null
+++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
@@ -0,0 +1,77 @@
+//===- TestPatterns.cpp - LLVM dialect test patterns ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// 1:N replacement: this op's single result is replaced by all its operands.
+struct TestDirectReplacementOp : public ConversionPattern {
+ TestDirectReplacementOp(MLIRContext *ctx, const TypeConverter &converter)
+ : ConversionPattern(converter, "test.direct_replacement", 1, ctx) {}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ if (op->getNumResults() != 1)
+ return failure();
+ rewriter.replaceOpWithMultiple(op, {operands});
+ return success();
+ }
+};
+
+struct TestLLVMLegalizePatternsPass
+ : public PassWrapper<TestLLVMLegalizePatternsPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLLVMLegalizePatternsPass)
+
+ StringRef getArgument() const final { return "test-llvm-legalize-patterns"; }
+ StringRef getDescription() const final {
+ return "Run LLVM dialect legalization patterns";
+ }
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<LLVM::LLVMDialect>();
+ }
+
+ void runOnOperation() override {
+ MLIRContext *ctx = &getContext();
+ LLVMTypeConverter converter(ctx);
+ mlir::RewritePatternSet patterns(ctx);
+ patterns.add<TestDirectReplacementOp>(ctx, converter);
+
+ // Conversion target: only "test.legal_op" is explicitly marked legal.
+ ConversionTarget target(*ctx);
+ target.addLegalOp(OperationName("test.legal_op", ctx));
+
+ // Partial conversion; ops that cannot be legalized are collected here.
+ DenseSet<Operation *> unlegalizedOps;
+ ConversionConfig config;
+ config.unlegalizedOps = &unlegalizedOps;
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns), config)))
+ getOperation()->emitError() << "applyPartialConversion failed";
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pass registration (invoked from mlir-opt's registerTestPasses())
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace test {
+void registerTestLLVMLegalizePatternsPass() {
+ PassRegistration<TestLLVMLegalizePatternsPass>(); // -test-llvm-legalize-patterns
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 002c3900056dee..94bc67a1e96093 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -113,6 +113,7 @@ void registerTestLinalgRankReduceContractionOps();
void registerTestLinalgTransforms();
void registerTestLivenessAnalysisPass();
void registerTestLivenessPass();
+void registerTestLLVMLegalizePatternsPass();
void registerTestLoopFusion();
void registerTestLoopMappingPass();
void registerTestLoopUnrollingPass();
@@ -250,6 +251,7 @@ void registerTestPasses() {
mlir::test::registerTestLinalgTransforms();
mlir::test::registerTestLivenessAnalysisPass();
mlir::test::registerTestLivenessPass();
+ mlir::test::registerTestLLVMLegalizePatternsPass();
mlir::test::registerTestLoopFusion();
mlir::test::registerTestLoopMappingPass();
mlir::test::registerTestLoopUnrollingPass();
More information about the Mlir-commits
mailing list