[Mlir-commits] [mlir] [mlir][Transforms] Dialect Conversion: Add 1:N support to `remapInput` (PR #131454)
Matthias Springer
llvmlistbot at llvm.org
Sat Mar 15 08:53:20 PDT 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/131454
>From 3bdc7471f62dcd86399b757c94db1a09db9f2bc4 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 15 Mar 2025 14:19:48 +0100
Subject: [PATCH 1/2] [mlir][Transforms] Dialect Conversion: Add 1:N support to
`remapInput`
---
.../mlir/Transforms/DialectConversion.h | 11 ++++---
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 4 +--
.../Transforms/Utils/DialectConversion.cpp | 19 ++++++-----
mlir/test/Transforms/test-legalizer.mlir | 27 +++++++++++++--
mlir/test/lib/Dialect/Test/TestOps.td | 12 +++++--
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 33 ++++++++++++-------
6 files changed, 74 insertions(+), 32 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f54397e942ae0..93e98bfd169cb 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -65,11 +65,14 @@ class TypeConverter {
SignatureConversion(unsigned numOrigInputs)
: remappedInputs(numOrigInputs) {}
- /// This struct represents a range of new types or a single value that
+ /// This struct represents a range of new types or a range of values that
/// remaps an existing signature input.
struct InputMapping {
size_t inputNo, size;
- Value replacementValue;
+ SmallVector<Value, 1> replacementValues;
+
+ /// Return "true" if this input was replaced with one or multiple values.
+ bool replacedWithValues() const { return !replacementValues.empty(); }
};
/// Return the argument types for the new signature.
@@ -92,9 +95,9 @@ class TypeConverter {
/// used if the new types are not intended to remap an existing input.
void addInputs(ArrayRef<Type> types);
- /// Remap an input of the original signature to another `replacement`
+ /// Remap an input of the original signature to another `replacements`
/// value. This drops the original argument.
- void remapInput(unsigned origInputNo, Value replacement);
+ void remapInput(unsigned origInputNo, ArrayRef<Value> replacements);
private:
/// Remap an input of the original signature with a range of types in the
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 8b6c62ca2e36d..f22ad1fd70db2 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -274,7 +274,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
// and canonicalize that away later.
Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
auto type = cast<MemRefType>(attribution.getType());
- auto descr = MemRefDescriptor::fromStaticShape(
+ Value descr = MemRefDescriptor::fromStaticShape(
rewriter, loc, *getTypeConverter(), type, memory);
signatureConversion.remapInput(numProperArguments + idx, descr);
}
@@ -303,7 +303,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
alignment = alignAttr.getInt();
Value allocated = rewriter.create<LLVM::AllocaOp>(
gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
- auto descr = MemRefDescriptor::fromStaticShape(
+ Value descr = MemRefDescriptor::fromStaticShape(
rewriter, loc, *getTypeConverter(), type, allocated);
signatureConversion.remapInput(
numProperArguments + numWorkgroupAttributions + idx, descr);
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7bfcd192c9aa9..9779436c947cf 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1341,7 +1341,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
rewriter.getUnknownLoc());
for (unsigned i = 0; i < origArgCount; ++i) {
auto inputMap = signatureConversion.getInputMapping(i);
- if (!inputMap || inputMap->replacementValue)
+ if (!inputMap || inputMap->replacedWithValues())
continue;
Location origLoc = block->getArgument(i).getLoc();
for (unsigned j = 0; j < inputMap->size; ++j)
@@ -1390,12 +1390,12 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
continue;
}
- if (Value repl = inputMap->replacementValue) {
- // This block argument was dropped and a replacement value was provided.
+ if (inputMap->replacedWithValues()) {
+ // This block argument was dropped and replacement values were provided.
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
- mapping.map(origArg, repl);
+ mapping.map(origArg, inputMap->replacementValues);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
continue;
}
@@ -2807,14 +2807,15 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
assert(!remappedInputs[origInputNo] && "input has already been remapped");
assert(newInputCount != 0 && "expected valid input count");
remappedInputs[origInputNo] =
- InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
+ InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
}
-void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
- Value replacementValue) {
+void TypeConverter::SignatureConversion::remapInput(
+ unsigned origInputNo, ArrayRef<Value> replacements) {
assert(!remappedInputs[origInputNo] && "input has already been remapped");
- remappedInputs[origInputNo] =
- InputMapping{origInputNo, /*size=*/0, replacementValue};
+ remappedInputs[origInputNo] = InputMapping{
+ origInputNo, /*size=*/0,
+ SmallVector<Value, 1>(replacements.begin(), replacements.end())};
}
LogicalResult TypeConverter::convertType(Type t,
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index ae7d344b7167f..34948ae685f0a 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -472,11 +472,32 @@ func.func @circular_mapping() {
// -----
-func.func @test_1_to_n_block_signature_conversion() {
- "test.duplicate_block_args"() ({
+// CHECK-LABEL: func @test_duplicate_block_arg()
+// CHECK: test.convert_block_args is_legal duplicate {
+// CHECK: ^{{.*}}(%[[arg0:.*]]: i64, %[[arg1:.*]]: i64):
+// CHECK: "test.valid"(%[[arg0]], %[[arg1]])
+// CHECK: }
+func.func @test_duplicate_block_arg() {
+ test.convert_block_args duplicate {
^bb0(%arg0: i64):
"test.repetitive_1_to_n_consumer"(%arg0) : (i64) -> ()
- }) {} : () -> ()
+ } : () -> ()
+ "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func @test_remap_block_arg()
+// CHECK: %[[repl:.*]] = "test.legal_op"() : () -> i32
+// CHECK: test.convert_block_args %[[repl]] is_legal replace_with_operand {
+// CHECK-NEXT: "test.valid"(%[[repl]], %[[repl]])
+// CHECK: }
+func.func @test_remap_block_arg() {
+ %0 = "test.legal_op"() : () -> (i32)
+ test.convert_block_args %0 replace_with_operand {
+ ^bb0(%arg0: i32):
+ "test.repetitive_1_to_n_consumer"(%arg0) : (i32) -> ()
+ } : (i32) -> ()
"test.return"() : () -> ()
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index b90afe61b8097..94c722038f1cc 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1940,9 +1940,17 @@ def LegalOpC : TEST_Op<"legal_op_c">,
Arguments<(ins I32)>, Results<(outs I32)>;
def LegalOpD : TEST_Op<"legal_op_d">, Arguments<(ins AnyType)>;
-def DuplicateBlockArgsOp : TEST_Op<"duplicate_block_args", [SingleBlock]> {
- let arguments = (ins UnitAttr:$is_legal);
+def ConvertBlockArgsOp : TEST_Op<"convert_block_args", [SingleBlock]> {
+ let arguments = (ins UnitAttr:$is_legal, UnitAttr:$replace_with_operand,
+ UnitAttr:$duplicate, Optional<AnyType>:$val);
let regions = (region SizedRegion<1>:$body);
+ let assemblyFormat = [{
+ $val
+ (`is_legal` $is_legal^)?
+ (`duplicate` $duplicate^)?
+ (`replace_with_operand` $replace_with_operand^)?
+ $body attr-dict `:` functional-type(operands, results)
+ }];
}
// Check that the conversion infrastructure can properly undo the creation of
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index b20e0816bd17c..b868f1a3a08da 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1193,22 +1193,31 @@ class TestEraseOp : public ConversionPattern {
}
};
-/// This pattern matches a test.duplicate_block_args op and duplicates all
-/// block arguments.
-class TestDuplicateBlockArgs
- : public OpConversionPattern<DuplicateBlockArgsOp> {
- using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern;
+/// This pattern matches a test.convert_block_args op. It either:
+/// a) duplicates all block arguments (`duplicate`), or
+/// b) drops all block arguments and replaces each with two copies of the
+/// first operand (`replace_with_operand`).
+class TestConvertBlockArgs : public OpConversionPattern<ConvertBlockArgsOp> {
+ using OpConversionPattern<ConvertBlockArgsOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor,
+ matchAndRewrite(ConvertBlockArgsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op.getIsLegal())
return failure();
- rewriter.startOpModification(op);
Block *body = &op.getBody().front();
TypeConverter::SignatureConversion result(body->getNumArguments());
- for (auto it : llvm::enumerate(body->getArgumentTypes()))
- result.addInputs(it.index(), {it.value(), it.value()});
+ for (auto it : llvm::enumerate(body->getArgumentTypes())) {
+ if (op.getReplaceWithOperand()) {
+ result.remapInput(it.index(), {adaptor.getVal(), adaptor.getVal()});
+ } else if (op.getDuplicate()) {
+ result.addInputs(it.index(), {it.value(), it.value()});
+ } else {
+ // No action specified. Pattern does not apply.
+ return failure();
+ }
+ }
+ rewriter.startOpModification(op);
rewriter.applySignatureConversion(body, result, getTypeConverter());
op.setIsLegal(true);
rewriter.finalizeOpModification(op);
@@ -1355,7 +1364,7 @@ struct TestLegalizePatternDriver
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
&getContext(), converter);
- patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
+ patterns.add<TestConvertBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -1406,8 +1415,8 @@ struct TestLegalizePatternDriver
target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
[](TestOpInPlaceSelfFold op) { return op.getFolded(); });
- target.addDynamicallyLegalOp<DuplicateBlockArgsOp>(
- [](DuplicateBlockArgsOp op) { return op.getIsLegal(); });
+ target.addDynamicallyLegalOp<ConvertBlockArgsOp>(
+ [](ConvertBlockArgsOp op) { return op.getIsLegal(); });
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
>From daf96e10470c620c9c2e1c49b89fdfe3e55c597c Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 15 Mar 2025 16:53:13 +0100
Subject: [PATCH 2/2] Update mlir/include/mlir/Transforms/DialectConversion.h
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>
---
mlir/include/mlir/Transforms/DialectConversion.h | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 93e98bfd169cb..8a70883293d91 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -95,8 +95,8 @@ class TypeConverter {
/// used if the new types are not intended to remap an existing input.
void addInputs(ArrayRef<Type> types);
- /// Remap an input of the original signature to another `replacements`
- /// value. This drops the original argument.
+ /// Remap an input of the original signature to the values in
+ /// `replacements`. This drops the original argument.
void remapInput(unsigned origInputNo, ArrayRef<Value> replacements);
private:
More information about the Mlir-commits
mailing list