[Mlir-commits] [mlir] [mlir][LLVM] `LLVMTypeConverter`: Tighten materialization checks (PR #116532)
Matthias Springer
llvmlistbot at llvm.org
Sat Nov 23 19:10:38 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116532
>From 1511e509b5c930bfa2f565948141ebe577467c0b Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 17 Nov 2024 09:00:45 +0100
Subject: [PATCH] [mlir][LLVM] `LLVMTypeConverter`: Tighten materialization
checks
---
.../Conversion/LLVMCommon/TypeConverter.cpp | 32 ++++----
.../MemRefToLLVM/type-conversion.mlir | 57 ++++++++++++++
mlir/test/lib/Dialect/LLVM/CMakeLists.txt | 1 +
mlir/test/lib/Dialect/LLVM/TestPatterns.cpp | 77 +++++++++++++++++++
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
5 files changed, 154 insertions(+), 15 deletions(-)
create mode 100644 mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
create mode 100644 mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ce91424e7a577e..59b0f5c9b09bcd 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,6 +153,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
type.isVarArg());
});
+ // Helper function that checks if the given value range is a bare pointer.
+ auto isBarePointer = [](ValueRange values) {
+ return values.size() == 1 &&
+ isa<LLVM::LLVMPointerType>(values.front().getType());
+ };
+
// Argument materializations convert from the new block argument types
// (multiple SSA values that make up a memref descriptor) back to the
// original block argument type. The dialect conversion framework will then
@@ -161,11 +167,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addArgumentMaterialization([&](OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs, Location loc) {
- if (inputs.size() == 1) {
- // Bare pointers are not supported for unranked memrefs because a
- // memref descriptor cannot be built just from a bare pointer.
+ // Note: Bare pointers are not supported for unranked memrefs because a
+ // memref descriptor cannot be built just from a bare pointer.
+ if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
return Value();
- }
Value desc =
UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
// An argument materialization must return a value of type
@@ -177,20 +182,17 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs, Location loc) {
Value desc;
- if (inputs.size() == 1) {
- // This is a bare pointer. We allow bare pointers only for function entry
- // blocks.
- BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
- if (!barePtr)
- return Value();
- Block *block = barePtr.getOwner();
- if (!block->isEntryBlock() ||
- !isa<FunctionOpInterface>(block->getParentOp()))
- return Value();
+ if (isBarePointer(inputs)) {
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
inputs[0]);
- } else {
+ } else if (TypeRange(inputs) ==
+ getMemRefDescriptorFields(resultType,
+ /*unpackAggregates=*/true)) {
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+ } else {
+ // The inputs are neither a bare pointer nor an unpacked memref
+ // descriptor. This materialization function cannot be used.
+ return Value();
}
// An argument materialization must return a value of type `resultType`,
// so insert a cast from the memref descriptor type (!llvm.struct) to the
diff --git a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
new file mode 100644
index 00000000000000..0288aa11313c72
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file | FileCheck %s
+
+// Test the argument materializer for ranked MemRef types.
+
+// CHECK-LABEL: func @construct_ranked_memref_descriptor(
+// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-COUNT-7: llvm.insertvalue
+// CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<5x4xf32>
+func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) {
+  %0 = "test.direct_replacement"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64) -> (memref<5x4xf32>)
+  "test.legal_op"(%0) : (memref<5x4xf32>) -> ()
+  return
+}
+
+// -----
+
+// The argument materializer for ranked MemRef types is called with incorrect
+// input types. Make sure that the materializer is skipped and we do not
+// generate invalid IR.
+
+// CHECK-LABEL: func @invalid_ranked_memref_descriptor(
+// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<5x4xf32>
+// CHECK: "test.legal_op"(%[[cast]])
+func.func @invalid_ranked_memref_descriptor(%arg0: i1) {
+  %0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<5x4xf32>)
+  "test.legal_op"(%0) : (memref<5x4xf32>) -> ()
+  return
+}
+
+// -----
+
+// Test the argument materializer for unranked MemRef types.
+
+// CHECK-LABEL: func @construct_unranked_memref_descriptor(
+// CHECK: llvm.mlir.undef : !llvm.struct<(i64, ptr)>
+// CHECK-COUNT-2: llvm.insertvalue
+// CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(i64, ptr)> to memref<*xf32>
+func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) {
+  %0 = "test.direct_replacement"(%arg0, %arg1) : (i64, !llvm.ptr) -> (memref<*xf32>)
+  "test.legal_op"(%0) : (memref<*xf32>) -> ()
+  return
+}
+
+// -----
+
+// The argument materializer for unranked MemRef types is called with incorrect
+// input types. Make sure that the materializer is skipped and we do not
+// generate invalid IR.
+
+// CHECK-LABEL: func @invalid_unranked_memref_descriptor(
+// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<*xf32>
+// CHECK: "test.legal_op"(%[[cast]])
+func.func @invalid_unranked_memref_descriptor(%arg0: i1) {
+  %0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<*xf32>)
+  "test.legal_op"(%0) : (memref<*xf32>) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/LLVM/CMakeLists.txt b/mlir/test/lib/Dialect/LLVM/CMakeLists.txt
index 734757ce79da37..6a2f0ba2756d43 100644
--- a/mlir/test/lib/Dialect/LLVM/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/LLVM/CMakeLists.txt
@@ -1,6 +1,7 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRLLVMTestPasses
TestLowerToLLVM.cpp
+ TestPatterns.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
new file mode 100644
index 00000000000000..ab02866970b1d5
--- /dev/null
+++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
@@ -0,0 +1,77 @@
+//===- TestPatterns.cpp - LLVM dialect test patterns ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Replaces the single result of a `test.direct_replacement` op with all of
+/// its converted operands (1:N). Ops without exactly one result don't match.
+struct TestDirectReplacementOp : public ConversionPattern {
+  TestDirectReplacementOp(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.direct_replacement", 1, ctx) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (op->getNumResults() != 1)
+      return failure();
+    rewriter.replaceOpWithMultiple(op, {operands});
+    return success();
+  }
+};
+
+/// Runs `TestDirectReplacementOp` with an `LLVMTypeConverter` as a partial
+/// conversion so lit tests can exercise the converter's materializations.
+struct TestLLVMLegalizePatternsPass
+    : public PassWrapper<TestLLVMLegalizePatternsPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLLVMLegalizePatternsPass)
+
+  StringRef getArgument() const final { return "test-llvm-legalize-patterns"; }
+  StringRef getDescription() const final {
+    return "Run LLVM dialect legalization patterns";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<LLVM::LLVMDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    LLVMTypeConverter converter(ctx);
+    RewritePatternSet patterns(ctx);
+    patterns.add<TestDirectReplacementOp>(ctx, converter);
+
+    // Only `test.legal_op` is legal; other ops have unknown legality.
+    ConversionTarget target(*ctx);
+    target.addLegalOp(OperationName("test.legal_op", ctx));
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns)))) {
+      getOperation()->emitError() << "applyPartialConversion failed";
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// PassRegistration
+//===----------------------------------------------------------------------===//
+
+namespace mlir::test {
+/// Registers `TestLLVMLegalizePatternsPass` so that tools such as mlir-opt can
+/// run it via `-test-llvm-legalize-patterns`.
+void registerTestLLVMLegalizePatternsPass() {
+  PassRegistration<TestLLVMLegalizePatternsPass> registration;
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 002c3900056dee..94bc67a1e96093 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -113,6 +113,7 @@ void registerTestLinalgRankReduceContractionOps();
void registerTestLinalgTransforms();
void registerTestLivenessAnalysisPass();
void registerTestLivenessPass();
+void registerTestLLVMLegalizePatternsPass();
void registerTestLoopFusion();
void registerTestLoopMappingPass();
void registerTestLoopUnrollingPass();
@@ -250,6 +251,7 @@ void registerTestPasses() {
mlir::test::registerTestLinalgTransforms();
mlir::test::registerTestLivenessAnalysisPass();
mlir::test::registerTestLivenessPass();
+ mlir::test::registerTestLLVMLegalizePatternsPass();
mlir::test::registerTestLoopFusion();
mlir::test::registerTestLoopMappingPass();
mlir::test::registerTestLoopUnrollingPass();
More information about the Mlir-commits
mailing list