[Mlir-commits] [mlir] [mlir][Transforms] Dialect Conversion: Add 1:N support to `remapInput` (PR #131454)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 15 06:22:55 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit adds 1:N support to `SignatureConversion::remapInput`. This API allows users to replace a block argument with one or more replacement values; the original block argument is dropped. `SignatureConversion` already supported "bbarg --> multiple bbargs" mappings, but "bbarg --> multiple SSA values" was missing.


---
Full diff: https://github.com/llvm/llvm-project/pull/131454.diff


6 Files Affected:

- (modified) mlir/include/mlir/Transforms/DialectConversion.h (+7-4) 
- (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp (+2-2) 
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+10-9) 
- (modified) mlir/test/Transforms/test-legalizer.mlir (+24-3) 
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+10-2) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+21-12) 


``````````diff
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f54397e942ae0..3e84331ffc1c5 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -65,11 +65,14 @@ class TypeConverter {
     SignatureConversion(unsigned numOrigInputs)
         : remappedInputs(numOrigInputs) {}
 
-    /// This struct represents a range of new types or a single value that
+    /// This struct represents a range of new types or a range of values that
     /// remaps an existing signature input.
     struct InputMapping {
       size_t inputNo, size;
-      Value replacementValue;
+      SmallVector<Value, 1> replacementValues;
+
+      /// Return "true" if this input was replaced with one or multiple values.
+      bool replacedWithValues() const { return !replacementValues.empty(); }
     };
 
     /// Return the argument types for the new signature.
@@ -92,9 +95,9 @@ class TypeConverter {
     /// used if the new types are not intended to remap an existing input.
     void addInputs(ArrayRef<Type> types);
 
-    /// Remap an input of the original signature to another `replacement`
+    /// Remap an input of the original signature to another `replacements`
     /// value. This drops the original argument.
-    void remapInput(unsigned origInputNo, Value replacement);
+    void remapInput(unsigned origInputNo, ArrayRef<Value> replacements);
 
   private:
     /// Remap an input of the original signature with a range of types in the
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 8b6c62ca2e36d..f22ad1fd70db2 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -274,7 +274,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
         // and canonicalize that away later.
         Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
         auto type = cast<MemRefType>(attribution.getType());
-        auto descr = MemRefDescriptor::fromStaticShape(
+        Value descr = MemRefDescriptor::fromStaticShape(
             rewriter, loc, *getTypeConverter(), type, memory);
         signatureConversion.remapInput(numProperArguments + idx, descr);
       }
@@ -303,7 +303,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
         alignment = alignAttr.getInt();
       Value allocated = rewriter.create<LLVM::AllocaOp>(
           gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
-      auto descr = MemRefDescriptor::fromStaticShape(
+      Value descr = MemRefDescriptor::fromStaticShape(
           rewriter, loc, *getTypeConverter(), type, allocated);
       signatureConversion.remapInput(
           numProperArguments + numWorkgroupAttributions + idx, descr);
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7bfcd192c9aa9..9779436c947cf 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1341,7 +1341,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
                                 rewriter.getUnknownLoc());
   for (unsigned i = 0; i < origArgCount; ++i) {
     auto inputMap = signatureConversion.getInputMapping(i);
-    if (!inputMap || inputMap->replacementValue)
+    if (!inputMap || inputMap->replacedWithValues())
       continue;
     Location origLoc = block->getArgument(i).getLoc();
     for (unsigned j = 0; j < inputMap->size; ++j)
@@ -1390,12 +1390,12 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       continue;
     }
 
-    if (Value repl = inputMap->replacementValue) {
-      // This block argument was dropped and a replacement value was provided.
+    if (inputMap->replacedWithValues()) {
+      // This block argument was dropped and replacement values were provided.
       assert(inputMap->size == 0 &&
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
-      mapping.map(origArg, repl);
+      mapping.map(origArg, inputMap->replacementValues);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
       continue;
     }
@@ -2807,14 +2807,15 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
   assert(!remappedInputs[origInputNo] && "input has already been remapped");
   assert(newInputCount != 0 && "expected valid input count");
   remappedInputs[origInputNo] =
-      InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
+      InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
 }
 
-void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
-                                                    Value replacementValue) {
+void TypeConverter::SignatureConversion::remapInput(
+    unsigned origInputNo, ArrayRef<Value> replacements) {
   assert(!remappedInputs[origInputNo] && "input has already been remapped");
-  remappedInputs[origInputNo] =
-      InputMapping{origInputNo, /*size=*/0, replacementValue};
+  remappedInputs[origInputNo] = InputMapping{
+      origInputNo, /*size=*/0,
+      SmallVector<Value, 1>(replacements.begin(), replacements.end())};
 }
 
 LogicalResult TypeConverter::convertType(Type t,
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index ae7d344b7167f..34948ae685f0a 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -472,11 +472,32 @@ func.func @circular_mapping() {
 
 // -----
 
-func.func @test_1_to_n_block_signature_conversion() {
-  "test.duplicate_block_args"() ({
+// CHECK-LABEL: func @test_duplicate_block_arg()
+//       CHECK:   test.convert_block_args  is_legal duplicate {
+//       CHECK:   ^{{.*}}(%[[arg0:.*]]: i64, %[[arg1:.*]]: i64):
+//       CHECK:     "test.valid"(%[[arg0]], %[[arg1]])
+//       CHECK:   }
+func.func @test_duplicate_block_arg() {
+  test.convert_block_args duplicate {
   ^bb0(%arg0: i64):
     "test.repetitive_1_to_n_consumer"(%arg0) : (i64) -> ()
-  }) {} : () -> ()
+  } : () -> ()
+  "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func @test_remap_block_arg()
+//       CHECK:      %[[repl:.*]] = "test.legal_op"() : () -> i32
+//       CHECK:      test.convert_block_args %[[repl]]  is_legal replace_with_operand {
+//       CHECK-NEXT:   "test.valid"(%[[repl]], %[[repl]])
+//       CHECK:      }
+func.func @test_remap_block_arg() {
+  %0 = "test.legal_op"() : () -> (i32)
+  test.convert_block_args %0 replace_with_operand {
+  ^bb0(%arg0: i32):
+    "test.repetitive_1_to_n_consumer"(%arg0) : (i32) -> ()
+  } : (i32) -> ()
   "test.return"() : () -> ()
 }
 
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index b90afe61b8097..94c722038f1cc 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1940,9 +1940,17 @@ def LegalOpC : TEST_Op<"legal_op_c">,
   Arguments<(ins I32)>, Results<(outs I32)>;
 def LegalOpD : TEST_Op<"legal_op_d">, Arguments<(ins AnyType)>;
 
-def DuplicateBlockArgsOp : TEST_Op<"duplicate_block_args", [SingleBlock]> {
-  let arguments = (ins UnitAttr:$is_legal);
+def ConvertBlockArgsOp : TEST_Op<"convert_block_args", [SingleBlock]> {
+  let arguments = (ins UnitAttr:$is_legal, UnitAttr:$replace_with_operand,
+                       UnitAttr:$duplicate, Optional<AnyType>:$val);
   let regions = (region SizedRegion<1>:$body);
+  let assemblyFormat = [{
+    $val
+    (`is_legal` $is_legal^)?
+    (`duplicate` $duplicate^)?
+    (`replace_with_operand` $replace_with_operand^)?
+    $body attr-dict `:` functional-type(operands, results)
+  }];
 }
 
 // Check that the conversion infrastructure can properly undo the creation of
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index b20e0816bd17c..d27fc097337ff 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1193,22 +1193,31 @@ class TestEraseOp : public ConversionPattern {
   }
 };
 
-/// This pattern matches a test.duplicate_block_args op and duplicates all
-/// block arguments.
-class TestDuplicateBlockArgs
-    : public OpConversionPattern<DuplicateBlockArgsOp> {
-  using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern;
+/// This pattern matches a test.convert_block_args op. It either:
+/// a) Duplicates all block arguments,
+/// b) drops all block arguments and replaces each with two copies of the
+///    first operand.
+class TestConvertBlockArgs : public OpConversionPattern<ConvertBlockArgsOp> {
+  using OpConversionPattern<ConvertBlockArgsOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor,
+  matchAndRewrite(ConvertBlockArgsOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (op.getIsLegal())
       return failure();
-    rewriter.startOpModification(op);
     Block *body = &op.getBody().front();
     TypeConverter::SignatureConversion result(body->getNumArguments());
-    for (auto it : llvm::enumerate(body->getArgumentTypes()))
-      result.addInputs(it.index(), {it.value(), it.value()});
+    for (auto it : llvm::enumerate(body->getArgumentTypes())) {
+      if (op.getReplaceWithOperand()) {
+        result.remapInput(it.index(), {adaptor.getVal(), adaptor.getVal()});
+      } else if (op.getDuplicate()) {
+        result.addInputs(it.index(), {it.value(), it.value()});
+      } else {
+        // No action specified. Pattern does not apply.
+        return failure();
+      }
+    }
+    rewriter.startOpModification(op);
     rewriter.applySignatureConversion(body, result, getTypeConverter());
     op.setIsLegal(true);
     rewriter.finalizeOpModification(op);
@@ -1355,7 +1364,7 @@ struct TestLegalizePatternDriver
     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
                  TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
         &getContext(), converter);
-    patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
+    patterns.add<TestConvertBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
     mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -1406,8 +1415,8 @@ struct TestLegalizePatternDriver
     target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
         [](TestOpInPlaceSelfFold op) { return op.getFolded(); });
 
-    target.addDynamicallyLegalOp<DuplicateBlockArgsOp>(
-        [](DuplicateBlockArgsOp op) { return op.getIsLegal(); });
+    target.addDynamicallyLegalOp<ConvertBlockArgsOp>(
+        [](ConvertBlockArgsOp op) { return op.getIsLegal(); });
 
     // Handle a partial conversion.
     if (mode == ConversionMode::Partial) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/131454


More information about the Mlir-commits mailing list