[Mlir-commits] [mlir] [mlir][draft] Support 1:N dialect conversion (PR #112141)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Oct 13 07:48:34 PDT 2024


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


Warning: the C/C++ code formatter clang-format found formatting issues in your code.

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff 9f24c145494ee238e65e25205a4dcb4451f009ae 7ec251bc0e69b4611d5acf8884be39d5461eb17b --extensions h,cpp -- mlir/include/mlir/Conversion/LLVMCommon/Pattern.h mlir/include/mlir/Transforms/DialectConversion.h mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp mlir/lib/Transforms/Utils/DialectConversion.cpp mlir/test/lib/Dialect/Test/TestPatterns.cpp mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp mlir/unittests/ExecutionEngine/Invoke.cpp
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 5285074ec6..f6cb99cfa9 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,55 +153,54 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                        type.isVarArg());
   });
 
-/*
-  // Argument materializations convert from the new block argument types
-  // (multiple SSA values that make up a memref descriptor) back to the
-  // original block argument type. The dialect conversion framework will then
-  // insert a target materialization from the original block argument type to
-  // a legal type.
-  addArgumentMaterialization(
-      [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
-          Location loc) -> std::optional<Value> {
-        if (inputs.size() == 1) {
-          // Bare pointers are not supported for unranked memrefs because a
-          // memref descriptor cannot be built just from a bare pointer.
+  /*
+    // Argument materializations convert from the new block argument types
+    // (multiple SSA values that make up a memref descriptor) back to the
+    // original block argument type. The dialect conversion framework will then
+    // insert a target materialization from the original block argument type to
+    // a legal type.
+    addArgumentMaterialization(
+        [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange
+    inputs, Location loc) -> std::optional<Value> { if (inputs.size() == 1) {
+            // Bare pointers are not supported for unranked memrefs because a
+            // memref descriptor cannot be built just from a bare pointer.
+            return std::nullopt;
+          }
+          Value desc = UnrankedMemRefDescriptor::pack(builder, loc, *this,
+                                                      resultType, inputs);
+          // An argument materialization must return a value of type
+          // `resultType`, so insert a cast from the memref descriptor type
+          // (!llvm.struct) to the original memref type.
+          return builder.create<UnrealizedConversionCastOp>(loc, resultType,
+    desc) .getResult(0);
+        });
+    addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
+                                   ValueRange inputs,
+                                   Location loc) -> std::optional<Value> {
+      Value desc;
+      if (inputs.size() == 1) {
+        // This is a bare pointer. We allow bare pointers only for function
+    entry
+        // blocks.
+        BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
+        if (!barePtr)
           return std::nullopt;
-        }
-        Value desc = UnrankedMemRefDescriptor::pack(builder, loc, *this,
-                                                    resultType, inputs);
-        // An argument materialization must return a value of type
-        // `resultType`, so insert a cast from the memref descriptor type
-        // (!llvm.struct) to the original memref type.
-        return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
-            .getResult(0);
-      });
-  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
-                                 ValueRange inputs,
-                                 Location loc) -> std::optional<Value> {
-    Value desc;
-    if (inputs.size() == 1) {
-      // This is a bare pointer. We allow bare pointers only for function entry
-      // blocks.
-      BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
-      if (!barePtr)
-        return std::nullopt;
-      Block *block = barePtr.getOwner();
-      if (!block->isEntryBlock() ||
-          !isa<FunctionOpInterface>(block->getParentOp()))
-        return std::nullopt;
-      desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
-                                               inputs[0]);
-    } else {
-      desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
-    }
-    // An argument materialization must return a value of type `resultType`,
-    // so insert a cast from the memref descriptor type (!llvm.struct) to the
-    // original memref type.
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
-        .getResult(0);
-  });
-
-*/
+        Block *block = barePtr.getOwner();
+        if (!block->isEntryBlock() ||
+            !isa<FunctionOpInterface>(block->getParentOp()))
+          return std::nullopt;
+        desc = MemRefDescriptor::fromStaticShape(builder, loc, *this,
+    resultType, inputs[0]); } else { desc = MemRefDescriptor::pack(builder, loc,
+    *this, resultType, inputs);
+      }
+      // An argument materialization must return a value of type `resultType`,
+      // so insert a cast from the memref descriptor type (!llvm.struct) to the
+      // original memref type.
+      return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
+          .getResult(0);
+    });
+
+  */
   // Add generic source and target materializations to handle cases where
   // non-LLVM types persist after an LLVM conversion.
   addSourceMaterialization([&](OpBuilder &builder, Type resultType,
@@ -211,14 +210,14 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
         .getResult(0);
   });
   addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs,
-                               Location loc, Type originalType) -> std::optional<Value> {
+                               ValueRange inputs, Location loc,
+                               Type originalType) -> std::optional<Value> {
     return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
         .getResult(0);
   });
   addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs,
-                               Location loc, Type originalType) -> std::optional<Value> {
+                               ValueRange inputs, Location loc,
+                               Type originalType) -> std::optional<Value> {
     llvm::errs() << "TARGET MAT: -> " << resultType << "\n";
     if (!originalType) {
       llvm::errs() << " -- no orig\n";
@@ -228,8 +227,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
       assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
       if (inputs.size() == 1) {
         Value input = inputs.front();
-        if (auto castOp =input.getDefiningOp<UnrealizedConversionCastOp>()) {
-          if (castOp.getInputs().size() == 1 && isa<LLVM::LLVMPointerType>(castOp.getInputs()[0].getType())) {
+        if (auto castOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
+          if (castOp.getInputs().size() == 1 &&
+              isa<LLVM::LLVMPointerType>(castOp.getInputs()[0].getType())) {
             input = castOp.getInputs()[0];
           }
         }
@@ -243,23 +243,23 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
             !isa<FunctionOpInterface>(block->getParentOp()))
           return std::nullopt;
         // Bare ptr
-        return MemRefDescriptor::fromStaticShape(builder, loc, *this, memrefType,
-                                                 input);
+        return MemRefDescriptor::fromStaticShape(builder, loc, *this,
+                                                 memrefType, input);
       }
       return MemRefDescriptor::pack(builder, loc, *this, memrefType, inputs);
     }
     if (auto memrefType = dyn_cast<UnrankedMemRefType>(originalType)) {
       assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
       if (inputs.size() == 1) {
-          // Bare pointers are not supported for unranked memrefs because a
-          // memref descriptor cannot be built just from a bare pointer.
-          return std::nullopt;
+        // Bare pointers are not supported for unranked memrefs because a
+        // memref descriptor cannot be built just from a bare pointer.
+        return std::nullopt;
       }
-      return UnrankedMemRefDescriptor::pack(builder, loc, *this,
-                                                    memrefType, inputs);
+      return UnrankedMemRefDescriptor::pack(builder, loc, *this, memrefType,
+                                            inputs);
     }
 
-      return std::nullopt;
+    return std::nullopt;
   });
 
   // Integer memory spaces map to themselves.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index efc2d24ddb..2b78a09b87 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1275,7 +1275,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     SmallVector<Value, 1> vals = mapping.lookupOrDefault(operand);
     ValueRange castValues = buildUnresolvedMaterialization(
         MaterializationKind::Target, computeInsertPoint(vals), operandLoc,
-        /*inputs=*/vals, /*outputTypes=*/legalTypes, /*originalType=*/origType, currentTypeConverter);
+        /*inputs=*/vals, /*outputTypes=*/legalTypes, /*originalType=*/origType,
+        currentTypeConverter);
 
     mapping.mapMaterialization(vals, castValues);
     remapped.push_back(castValues);
@@ -1430,7 +1431,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 /// of input operands.
 ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-    ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter) {
+    ValueRange inputs, TypeRange outputTypes, Type originalType,
+    const TypeConverter *converter) {
   // Avoid materializing an unnecessary cast.
   if (TypeRange(inputs) == outputTypes)
     return inputs;
@@ -1441,7 +1443,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
-  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind, originalType);
+  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
+                                                  originalType);
   return convertOp.getResults();
 }
 
@@ -1495,7 +1498,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
       repl = buildUnresolvedMaterialization(
           MaterializationKind::Source, computeInsertPoint(result),
           result.getLoc(), /*inputs=*/ValueRange(),
-          /*outputTypes=*/result.getType(), /*originalType=*/Type(), currentTypeConverter);
+          /*outputTypes=*/result.getType(), /*originalType=*/Type(),
+          currentTypeConverter);
     } else {
       // Make sure that the user does not mess with unresolved materializations
       // that were inserted by the conversion driver. We keep track of these
@@ -2735,8 +2739,8 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
       Value castValue = rewriterImpl.buildUnresolvedMaterialization(
           MaterializationKind::Source, computeInsertPoint(newValues),
           originalValue.getLoc(),
-          /*inputs=*/newValues, /*outputTypes=*/originalValue.getType(), /*originalType=*/Type(),
-          converter)[0];
+          /*inputs=*/newValues, /*outputTypes=*/originalValue.getType(),
+          /*originalType=*/Type(), converter)[0];
       rewriterImpl.mapping.mapMaterialization(newValues, {castValue});
       llvm::append_range(inverseMapping[castValue], newValues);
     }
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 291f6af84b..4465cfa4c1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -994,7 +994,8 @@ struct TestUpdateConsumerType : public ConversionPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    llvm::errs() << "TestUpdateConsumerType operand: " << operands.front() << "\n";
+    llvm::errs() << "TestUpdateConsumerType operand: " << operands.front()
+                 << "\n";
     // Verify that the incoming operand has been successfully remapped to F64.
     if (!operands[0].getType().isF64())
       return failure();

``````````

</details>


https://github.com/llvm/llvm-project/pull/112141


More information about the Mlir-commits mailing list