[llvm-branch-commits] [mlir] [mlir][LLVM] `LLVMTypeConverter`: Tighten materialization checks (PR #116532)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Nov 22 23:45:36 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116532

From 4e4a5c81a1c45c8d4fbadacd67fa5439231e912e Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 23 Nov 2024 08:22:13 +0100
Subject: [PATCH 1/2] [mlir][Func] Delete `DecomposeCallGraphTypes.cpp`

---
 .../Func/Transforms/DecomposeCallGraphTypes.h |  34 -----
 .../Dialect/Func/Transforms/CMakeLists.txt    |   1 -
 .../Transforms/DecomposeCallGraphTypes.cpp    | 136 ------------------
 .../Func/Transforms/FuncConversions.cpp       |   8 +-
 .../Func/TestDecomposeCallGraphTypes.cpp      |   7 +-
 5 files changed, 8 insertions(+), 178 deletions(-)
 delete mode 100644 mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
 delete mode 100644 mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp

diff --git a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
deleted file mode 100644
index 1be406bf3adf92..00000000000000
--- a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
+++ /dev/null
@@ -1,34 +0,0 @@
-//===- DecomposeCallGraphTypes.h - CG type decompositions -------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Conversion patterns for decomposing types along call graph edges. That is,
-// decomposing types for calls, returns, and function args.
-//
-// TODO: Make this handle dialect-defined functions, calls, and returns.
-// Currently, the generic interfaces aren't sophisticated enough for the
-// types of mutations that we are doing here.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
-#define MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
-
-#include "mlir/Transforms/DialectConversion.h"
-#include <optional>
-
-namespace mlir {
-
-/// Populates the patterns needed to drive the conversion process for
-/// decomposing call graph types with the given `TypeConverter`.
-void populateDecomposeCallGraphTypesPatterns(MLIRContext *context,
-                                             const TypeConverter &typeConverter,
-                                             RewritePatternSet &patterns);
-
-} // namespace mlir
-
-#endif // MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
diff --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
index f8fb1f436a95b1..6384d25ee70273 100644
--- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
@@ -1,5 +1,4 @@
 add_mlir_dialect_library(MLIRFuncTransforms
-  DecomposeCallGraphTypes.cpp
   DuplicateFunctionElimination.cpp
   FuncConversions.cpp
   OneToNFuncConversions.cpp
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
deleted file mode 100644
index 03be00328bda33..00000000000000
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ /dev/null
@@ -1,136 +0,0 @@
-//===- DecomposeCallGraphTypes.cpp - CG type decomposition ----------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinOps.h"
-
-using namespace mlir;
-using namespace mlir::func;
-
-//===----------------------------------------------------------------------===//
-// DecomposeCallGraphTypesForFuncArgs
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Expand function arguments according to the provided TypeConverter.
-struct DecomposeCallGraphTypesForFuncArgs
-    : public OpConversionPattern<func::FuncOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    auto functionType = op.getFunctionType();
-
-    // Convert function arguments using the provided TypeConverter.
-    TypeConverter::SignatureConversion conversion(functionType.getNumInputs());
-    for (const auto &argType : llvm::enumerate(functionType.getInputs())) {
-      SmallVector<Type, 2> decomposedTypes;
-      if (failed(typeConverter->convertType(argType.value(), decomposedTypes)))
-        return failure();
-      if (!decomposedTypes.empty())
-        conversion.addInputs(argType.index(), decomposedTypes);
-    }
-
-    // If the SignatureConversion doesn't apply, bail out.
-    if (failed(rewriter.convertRegionTypes(&op.getBody(), *getTypeConverter(),
-                                           &conversion)))
-      return failure();
-
-    // Update the signature of the function.
-    SmallVector<Type, 2> newResultTypes;
-    if (failed(typeConverter->convertTypes(functionType.getResults(),
-                                           newResultTypes)))
-      return failure();
-    rewriter.modifyOpInPlace(op, [&] {
-      op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
-                                          newResultTypes));
-    });
-    return success();
-  }
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// DecomposeCallGraphTypesForReturnOp
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Expand return operands according to the provided TypeConverter.
-struct DecomposeCallGraphTypesForReturnOp
-    : public OpConversionPattern<ReturnOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    SmallVector<Value, 2> newOperands;
-    for (ValueRange operand : adaptor.getOperands())
-      llvm::append_range(newOperands, operand);
-    rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
-    return success();
-  }
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// DecomposeCallGraphTypesForCallOp
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Expand call op operands and results according to the provided TypeConverter.
-struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-
-    // Create the operands list of the new `CallOp`.
-    SmallVector<Value, 2> newOperands;
-    for (ValueRange operand : adaptor.getOperands())
-      llvm::append_range(newOperands, operand);
-
-    // Create the new result types for the new `CallOp` and track the number of
-    // replacement types for each original op result.
-    SmallVector<Type, 2> newResultTypes;
-    SmallVector<unsigned> expandedResultSizes;
-    for (Type resultType : op.getResultTypes()) {
-      unsigned oldSize = newResultTypes.size();
-      if (failed(typeConverter->convertType(resultType, newResultTypes)))
-        return failure();
-      expandedResultSizes.push_back(newResultTypes.size() - oldSize);
-    }
-
-    CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
-                                               newResultTypes, newOperands);
-
-    // Build a replacement value for each result to replace its uses.
-    SmallVector<ValueRange> replacedValues;
-    replacedValues.reserve(op.getNumResults());
-    unsigned startIdx = 0;
-    for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
-      ValueRange repl =
-          newCallOp.getResults().slice(startIdx, expandedResultSizes[i]);
-      replacedValues.push_back(repl);
-      startIdx += expandedResultSizes[i];
-    }
-    rewriter.replaceOpWithMultiple(op, replacedValues);
-    return success();
-  }
-};
-} // namespace
-
-void mlir::populateDecomposeCallGraphTypesPatterns(
-    MLIRContext *context, const TypeConverter &typeConverter,
-    RewritePatternSet &patterns) {
-  patterns
-      .add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
-           DecomposeCallGraphTypesForReturnOp>(typeConverter, context);
-}
diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index 9e7759bef6d8fd..a3638c8766a5c6 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -124,12 +124,10 @@ class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
   using OpConversionPattern<ReturnOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
+  matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    // For a return, all operands go to the results of the parent, so
-    // rewrite them all.
-    rewriter.modifyOpInPlace(op,
-                             [&] { op->setOperands(adaptor.getOperands()); });
+    rewriter.replaceOpWithNewOp<ReturnOp>(op,
+                                          flattenValues(adaptor.getOperands()));
     return success();
   }
 };
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index de511c58ae6ee0..09c5b4b2a0ad50 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -9,7 +9,7 @@
 #include "TestDialect.h"
 #include "TestOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -142,7 +142,10 @@ struct TestDecomposeCallGraphTypes
     typeConverter.addArgumentMaterialization(buildMakeTupleOp);
     typeConverter.addTargetMaterialization(buildDecomposeTuple);
 
-    populateDecomposeCallGraphTypesPatterns(context, typeConverter, patterns);
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+        patterns, typeConverter);
+    populateReturnOpTypeConversionPattern(patterns, typeConverter);
+    populateCallOpTypeConversionPattern(patterns, typeConverter);
 
     if (failed(applyPartialConversion(module, target, std::move(patterns))))
       return signalPassFailure();

From fe68c3c6702ae0de6549a6db3b014e6fb4dc898a Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 17 Nov 2024 09:00:45 +0100
Subject: [PATCH 2/2] [mlir][LLVM] `LLVMTypeConverter`: Tighten materialization
 checks

---
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 32 ++++----
 .../MemRefToLLVM/type-conversion.mlir         | 57 ++++++++++++++
 mlir/test/lib/Dialect/LLVM/CMakeLists.txt     |  1 +
 mlir/test/lib/Dialect/LLVM/TestPatterns.cpp   | 77 +++++++++++++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |  2 +
 5 files changed, 154 insertions(+), 15 deletions(-)
 create mode 100644 mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
 create mode 100644 mlir/test/lib/Dialect/LLVM/TestPatterns.cpp

diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ce91424e7a577e..59b0f5c9b09bcd 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,6 +153,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                        type.isVarArg());
   });
 
+  // Helper function that checks if the given value range is a bare pointer.
+  auto isBarePointer = [](ValueRange values) {
+    return values.size() == 1 &&
+           isa<LLVM::LLVMPointerType>(values.front().getType());
+  };
+
   // Argument materializations convert from the new block argument types
   // (multiple SSA values that make up a memref descriptor) back to the
   // original block argument type. The dialect conversion framework will then
@@ -161,11 +167,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addArgumentMaterialization([&](OpBuilder &builder,
                                  UnrankedMemRefType resultType,
                                  ValueRange inputs, Location loc) {
-    if (inputs.size() == 1) {
-      // Bare pointers are not supported for unranked memrefs because a
-      // memref descriptor cannot be built just from a bare pointer.
+    // Note: Bare pointers are not supported for unranked memrefs because a
+    // memref descriptor cannot be built just from a bare pointer.
+    if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
       return Value();
-    }
     Value desc =
         UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
     // An argument materialization must return a value of type
@@ -177,20 +182,17 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
                                  ValueRange inputs, Location loc) {
     Value desc;
-    if (inputs.size() == 1) {
-      // This is a bare pointer. We allow bare pointers only for function entry
-      // blocks.
-      BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
-      if (!barePtr)
-        return Value();
-      Block *block = barePtr.getOwner();
-      if (!block->isEntryBlock() ||
-          !isa<FunctionOpInterface>(block->getParentOp()))
-        return Value();
+    if (isBarePointer(inputs) && canConvertToBarePtr(resultType)) {
       desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
                                                inputs[0]);
-    } else {
+    } else if (TypeRange(inputs) ==
+               getMemRefDescriptorFields(resultType,
+                                         /*unpackAggregates=*/true)) {
       desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+    } else {
+      // The inputs are neither a bare pointer (to a memref with static shape,
+      // strides and offset) nor an unpacked memref descriptor. Bail out.
+      return Value();
     }
     // An argument materialization must return a value of type `resultType`,
     // so insert a cast from the memref descriptor type (!llvm.struct) to the
diff --git a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
new file mode 100644
index 00000000000000..0288aa11313c72
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file | FileCheck %s
+
+// Test the argument materializer for ranked MemRef types.
+
+//   CHECK-LABEL: func @construct_ranked_memref_descriptor(
+//         CHECK:   llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-COUNT-7:   llvm.insertvalue
+//         CHECK:   builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<5x4xf32>
+func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) {
+  %0 = "test.direct_replacement"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64) -> (memref<5x4xf32>)
+  "test.legal_op"(%0) : (memref<5x4xf32>) -> ()
+  return
+}
+
+// -----
+
+// The argument materializer for ranked MemRef types is called with incorrect
+// input types. Make sure that the materializer is skipped and we do not
+// generate invalid IR.
+
+// CHECK-LABEL: func @invalid_ranked_memref_descriptor(
+//       CHECK:   %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<5x4xf32>
+//       CHECK:   "test.legal_op"(%[[cast]])
+func.func @invalid_ranked_memref_descriptor(%arg0: i1) {
+  %0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<5x4xf32>)
+  "test.legal_op"(%0) : (memref<5x4xf32>) -> ()
+  return
+}
+
+// -----
+
+// Test the argument materializer for unranked MemRef types.
+
+//   CHECK-LABEL: func @construct_unranked_memref_descriptor(
+//         CHECK:   llvm.mlir.undef : !llvm.struct<(i64, ptr)>
+// CHECK-COUNT-2:   llvm.insertvalue
+//         CHECK:   builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(i64, ptr)> to memref<*xf32>
+func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) {
+  %0 = "test.direct_replacement"(%arg0, %arg1) : (i64, !llvm.ptr) -> (memref<*xf32>)
+  "test.legal_op"(%0) : (memref<*xf32>) -> ()
+  return
+}
+
+// -----
+
+// The argument materializer for unranked MemRef types is called with incorrect
+// input types. Make sure that the materializer is skipped and we do not
+// generate invalid IR.
+
+// CHECK-LABEL: func @invalid_unranked_memref_descriptor(
+//       CHECK:   %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<*xf32>
+//       CHECK:   "test.legal_op"(%[[cast]])
+func.func @invalid_unranked_memref_descriptor(%arg0: i1) {
+  %0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<*xf32>)
+  "test.legal_op"(%0) : (memref<*xf32>) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/LLVM/CMakeLists.txt b/mlir/test/lib/Dialect/LLVM/CMakeLists.txt
index 734757ce79da37..6a2f0ba2756d43 100644
--- a/mlir/test/lib/Dialect/LLVM/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/LLVM/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRLLVMTestPasses
   TestLowerToLLVM.cpp
+  TestPatterns.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
diff --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
new file mode 100644
index 00000000000000..ab02866970b1d5
--- /dev/null
+++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
@@ -0,0 +1,77 @@
+//===- TestPatterns.cpp - LLVM dialect test patterns ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Replace this op (which is expected to have 1 result) with the operands.
+struct TestDirectReplacementOp : public ConversionPattern {
+  TestDirectReplacementOp(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.direct_replacement", 1, ctx) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (op->getNumResults() != 1)
+      return failure();
+    rewriter.replaceOpWithMultiple(op, {operands});
+    return success();
+  }
+};
+
+struct TestLLVMLegalizePatternsPass
+    : public PassWrapper<TestLLVMLegalizePatternsPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLLVMLegalizePatternsPass)
+
+  StringRef getArgument() const final { return "test-llvm-legalize-patterns"; }
+  StringRef getDescription() const final {
+    return "Run LLVM dialect legalization patterns";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<LLVM::LLVMDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    LLVMTypeConverter converter(ctx);
+    mlir::RewritePatternSet patterns(ctx);
+    patterns.add<TestDirectReplacementOp>(ctx, converter);
+
+    // Define the conversion target used for the test.
+    ConversionTarget target(*ctx);
+    target.addLegalOp(OperationName("test.legal_op", ctx));
+
+    // Handle a partial conversion.
+    DenseSet<Operation *> unlegalizedOps;
+    ConversionConfig config;
+    config.unlegalizedOps = &unlegalizedOps;
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns), config)))
+      getOperation()->emitError() << "applyPartialConversion failed";
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// PassRegistration
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace test {
+void registerTestLLVMLegalizePatternsPass() {
+  PassRegistration<TestLLVMLegalizePatternsPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 002c3900056dee..94bc67a1e96093 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -113,6 +113,7 @@ void registerTestLinalgRankReduceContractionOps();
 void registerTestLinalgTransforms();
 void registerTestLivenessAnalysisPass();
 void registerTestLivenessPass();
+void registerTestLLVMLegalizePatternsPass();
 void registerTestLoopFusion();
 void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
@@ -250,6 +251,7 @@ void registerTestPasses() {
   mlir::test::registerTestLinalgTransforms();
   mlir::test::registerTestLivenessAnalysisPass();
   mlir::test::registerTestLivenessPass();
+  mlir::test::registerTestLLVMLegalizePatternsPass();
   mlir::test::registerTestLoopFusion();
   mlir::test::registerTestLoopMappingPass();
   mlir::test::registerTestLoopUnrollingPass();



More information about the llvm-branch-commits mailing list