[Mlir-commits] [mlir] [mlir][Transforms] Dialect Conversion: No target mat. for 1:N replacement (PR #117513)
Matthias Springer
llvmlistbot at llvm.org
Sun Dec 15 11:29:36 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/117513
>From bf0d13553b2bc2124a266e398976ba80a1114580 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 14 Dec 2024 16:34:47 +0100
Subject: [PATCH 1/3] [mlir][Vector] Move mask materialization patterns to
greedy rewrite
The mask materialization patterns during `VectorToLLVM` are rewrite patterns. They should run as part of the greedy pattern rewrite and not the dialect conversion. (Rewrite patterns and conversion patterns are not generally compatible.)
The current combination of rewrite patterns and conversion patterns triggered an edge case when merging the 1:1 and 1:N dialect conversions.
---
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 7 +-
.../VectorToLLVM/vector-mask-to-llvm.mlir | 4 +-
.../VectorToLLVM/vector-to-llvm.mlir | 4 +-
.../VectorToLLVM/vector-xfer-to-llvm.mlir | 80 +++++++++----------
4 files changed, 44 insertions(+), 51 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 4623b9667998cc..64a9ad8e9bade0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -61,8 +61,8 @@ struct ConvertVectorToLLVMPass
} // namespace
void ConvertVectorToLLVMPass::runOnOperation() {
- // Perform progressive lowering of operations on slices and
- // all contraction operations. Also applies folding and DCE.
+ // Perform progressive lowering of operations on slices and all contraction
+ // operations. Also materializes masks, applies folding and DCE.
{
RewritePatternSet patterns(&getContext());
populateVectorToVectorCanonicalizationPatterns(patterns);
@@ -76,6 +76,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
VectorTransformsOptions());
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
+ populateVectorMaskMaterializationPatterns(patterns,
+ force32BitVectorIndices);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
@@ -83,7 +85,6 @@ void ConvertVectorToLLVMPass::runOnOperation() {
LowerToLLVMOptions options(&getContext());
LLVMTypeConverter converter(&getContext(), options);
RewritePatternSet patterns(&getContext());
- populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices);
populateVectorTransferLoweringPatterns(patterns);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
index 82351eb7c98a43..91e5358622b69d 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -7,7 +7,7 @@
// CMP32: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i32
// CMP32: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi32>
// CMP32: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<11xi32>
-// CMP32: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<11xi32>
+// CMP32: %[[T4:.*]] = arith.cmpi sgt, %[[T3]], %[[T0]] : vector<11xi32>
// CMP32: return %[[T4]] : vector<11xi1>
// CMP64-LABEL: @genbool_var_1d(
@@ -16,7 +16,7 @@
// CMP64: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i64
// CMP64: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi64>
// CMP64: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<11xi64>
-// CMP64: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<11xi64>
+// CMP64: %[[T4:.*]] = arith.cmpi sgt, %[[T3]], %[[T0]] : vector<11xi64>
// CMP64: return %[[T4]] : vector<11xi1>
func.func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 2473fe933ffcb2..ea88fece9e662d 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -3097,7 +3097,7 @@ func.func @create_mask_0d(%num_elems : index) -> vector<i1> {
// CHECK: %[[NUM_ELEMS_i32:.*]] = arith.index_cast %[[NUM_ELEMS]] : index to i32
// CHECK: %[[BOUNDS:.*]] = llvm.insertelement %[[NUM_ELEMS_i32]]
// CHECK: %[[BOUNDS_CAST:.*]] = builtin.unrealized_conversion_cast %[[BOUNDS]] : vector<1xi32> to vector<i32>
-// CHECK: %[[RESULT:.*]] = arith.cmpi slt, %[[INDICES]], %[[BOUNDS_CAST]] : vector<i32>
+// CHECK: %[[RESULT:.*]] = arith.cmpi sgt, %[[BOUNDS_CAST]], %[[INDICES]] : vector<i32>
// CHECK: return %[[RESULT]] : vector<i1>
// -----
@@ -3113,7 +3113,7 @@ func.func @create_mask_1d(%num_elems : index) -> vector<4xi1> {
// CHECK: %[[NUM_ELEMS_i32:.*]] = arith.index_cast %[[NUM_ELEMS]] : index to i32
// CHECK: %[[BOUNDS_INSERT:.*]] = llvm.insertelement %[[NUM_ELEMS_i32]]
// CHECK: %[[BOUNDS:.*]] = llvm.shufflevector %[[BOUNDS_INSERT]]
-// CHECK: %[[RESULT:.*]] = arith.cmpi slt, %[[INDICES]], %[[BOUNDS]] : vector<4xi32>
+// CHECK: %[[RESULT:.*]] = arith.cmpi sgt, %[[BOUNDS]], %[[INDICES]] : vector<4xi32>
// CHECK: return %[[RESULT]] : vector<4xi1>
// -----
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir
index 8f01cc2b8d44c3..d3f6d7eca90b41 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir
@@ -14,30 +14,28 @@ func.func @transfer_read_write_1d(%A : memref<?xf32>, %base: index) -> vector<17
// CHECK-LABEL: func @transfer_read_write_1d
// CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
// CHECK-SAME: %[[BASE:.*]]: index) -> vector<17xf32>
-// CHECK: %[[C7:.*]] = arith.constant 7.0
-//
-// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
-// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE]] : index
+// 1. Create pass-through vector.
+// CHECK-DAG: %[[PASS_THROUGH:.*]] = arith.constant dense<7.000000e+00> : vector<17xf32>
//
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-// CHECK: %[[linearIndex:.*]] = arith.constant dense
+// CHECK-DAG: %[[linearIndex:.*]] = arith.constant dense
// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<17x[[$IDX_TYPE]]>
//
-// 3. Create bound vector to compute in-bound mask:
+// 3. Let dim be the memref dimension, compute the in-bound index (dim - offset)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
+// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE]] : index
+//
+// 4. Create bound vector to compute in-bound mask:
// [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
// CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] :
// CMP32-SAME: index to i32
// CMP64-SAME: index to i64
// CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
// CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
-// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]] : vector<17x[[$IDX_TYPE]]>
+// CHECK: %[[mask:.*]] = arith.cmpi sgt, %[[boundVect]], %[[linearIndex]] : vector<17x[[$IDX_TYPE]]>
// CMP64-SAME: : vector<17xi64>
//
-// 4. Create pass-through vector.
-// CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<17xf32>
-//
// 5. Bitcast to vector form.
// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -48,28 +46,23 @@ func.func @transfer_read_write_1d(%A : memref<?xf32>, %base: index) -> vector<17
// CHECK-SAME: -> vector<17xf32>
//
// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-// CHECK: %[[C0_b:.*]] = arith.constant 0 : index
-// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
+// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
// CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
//
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-// CHECK: %[[linearIndex_b:.*]] = arith.constant dense
-// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<17x[[$IDX_TYPE]]>
-//
-// 3. Create bound vector to compute in-bound mask:
+// 2. Create bound vector to compute in-bound mask:
// [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
// CHECK: %[[btrunc_b:.*]] = arith.index_cast %[[BOUND_b]]
// CMP32-SAME: index to i32
// CHECK: %[[boundVecInsert_b:.*]] = llvm.insertelement %[[btrunc_b]]
// CHECK: %[[boundVect_b:.*]] = llvm.shufflevector %[[boundVecInsert_b]]
-// CHECK: %[[mask_b:.*]] = arith.cmpi slt, %[[linearIndex_b]],
-// CHECK-SAME: %[[boundVect_b]] : vector<17x[[$IDX_TYPE]]>
+// CHECK: %[[mask_b:.*]] = arith.cmpi sgt, %[[boundVect_b]],
+// CHECK-SAME: %[[linearIndex]] : vector<17x[[$IDX_TYPE]]>
//
-// 4. Bitcast to vector form.
+// 3. Bitcast to vector form.
// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
//
-// 5. Rewrite as a masked write.
+// 4. Rewrite as a masked write.
// CHECK: llvm.intr.masked.store %[[loaded]], %[[gep_b]], %[[mask_b]]
// CHECK-SAME: {alignment = 4 : i32} :
// CHECK-SAME: vector<17xf32>, vector<17xi1> into !llvm.ptr
@@ -87,17 +80,18 @@ func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) ->
// CHECK-LABEL: func @transfer_read_write_1d_scalable
// CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
// CHECK-SAME: %[[BASE:.*]]: index) -> vector<[17]xf32>
-// CHECK: %[[C7:.*]] = arith.constant 7.0
+// 1. Create pass-through vector.
+// CHECK-DAG: %[[PASS_THROUGH:.*]] = arith.constant dense<7.000000e+00> : vector<[17]xf32>
//
-// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// 2. Let dim be the memref dimension, compute the in-bound index (dim - offset)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE]] : index
//
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// 3. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// CHECK: %[[linearIndex:.*]] = llvm.intr.stepvector : vector<[17]x[[$IDX_TYPE]]>
//
-// 3. Create bound vector to compute in-bound mask:
+// 4. Create bound vector to compute in-bound mask:
// [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
// CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to [[$IDX_TYPE]]
// CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
@@ -105,9 +99,6 @@ func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) ->
// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
// CHECK-SAME: : vector<[17]x[[$IDX_TYPE]]>
//
-// 4. Create pass-through vector.
-// CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<[17]xf32>
-//
// 5. Bitcast to vector form.
// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -118,8 +109,7 @@ func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) ->
// CHECK-SAME: -> vector<[17]xf32>
//
// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-// CHECK: %[[C0_b:.*]] = arith.constant 0 : index
-// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
+// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
// CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
//
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
@@ -197,23 +187,23 @@ func.func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: i
}
// CHECK-LABEL: func @transfer_read_2d_to_1d
// CHECK-SAME: %[[BASE_0:[a-zA-Z0-9]*]]: index, %[[BASE_1:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
-// CHECK: %[[c1:.*]] = arith.constant 1 : index
+//
+// Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// CHECK-DAG: %[[linearIndex:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
+// CHECK-SAME: vector<17x[[$IDX_TYPE]]>
+//
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c1]] : memref<?x?xf32>
//
// Compute the in-bound index (dim - offset)
// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE_1]] : index
//
-// Create a vector with linear indices [ 0 .. vector_length - 1 ].
-// CHECK: %[[linearIndex:.*]] = arith.constant dense
-// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
-// CHECK-SAME: vector<17x[[$IDX_TYPE]]>
-//
// Create bound vector to compute in-bound mask:
// [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
// CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to [[$IDX_TYPE]]
// CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
// CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
-// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
+// CHECK: %[[mask:.*]] = arith.cmpi sgt, %[[boundVect]], %[[linearIndex]]
func.func @transfer_read_2d_to_1d_scalable(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<[17]xf32> {
%f7 = arith.constant 7.0: f32
@@ -255,12 +245,13 @@ func.func @transfer_read_write_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %bas
// CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace
// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
//
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+//
// 1. Check address space for GEP is correct.
// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
// CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
//
// 2. Check address space of the memref is correct.
-// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref<?xf32, 3>
//
// 3. Check address space for GEP is correct.
@@ -280,12 +271,13 @@ func.func @transfer_read_write_1d_non_zero_addrspace_scalable(%A : memref<?xf32,
// CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace_scalable
// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32>
//
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+//
// 1. Check address space for GEP is correct.
// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
// CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
//
// 2. Check address space of the memref is correct.
-// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref<?xf32, 3>
//
// 3. Check address space for GEP is correct.
@@ -330,10 +322,10 @@ func.func @transfer_read_1d_inbounds_scalable(%A : memref<?xf32>, %base: index)
// CHECK-LABEL: func @transfer_read_write_1d_mask
// CHECK: %[[mask1:.*]] = arith.constant dense<[false, false, true, false, true]>
-// CHECK: %[[cmpi:.*]] = arith.cmpi slt
+// CHECK: %[[cmpi:.*]] = arith.cmpi sgt
// CHECK: %[[mask2:.*]] = arith.andi %[[cmpi]], %[[mask1]]
// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask2]]
-// CHECK: %[[cmpi_1:.*]] = arith.cmpi slt
+// CHECK: %[[cmpi_1:.*]] = arith.cmpi sgt
// CHECK: %[[mask3:.*]] = arith.andi %[[cmpi_1]], %[[mask1]]
// CHECK: llvm.intr.masked.store %[[r]], %{{.*}}, %[[mask3]]
// CHECK: return %[[r]]
>From e5926b63835eec731c513d91ce9c451c429ca572 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 14 Dec 2024 17:19:18 +0100
Subject: [PATCH 2/3] [mlir][Vector] Clean up
`populateVectorToLLVMConversionPatterns`
---
.../Vector/Transforms/LoweringPatterns.h | 4 +++
.../GPUCommon/GPUToLLVMConversion.cpp | 12 +++++++++
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 27 ++++++++++---------
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 6 ++++-
.../Conversion/GPUCommon/lower-vector.mlir | 4 +--
.../VectorToLLVM/vector-to-llvm.mlir | 5 ----
6 files changed, 37 insertions(+), 21 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 3d643c96b45008..c507b23c6d4de6 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -292,6 +292,10 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
int64_t targetRank = 1,
PatternBenefit benefit = 1);
+/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
+/// n > 1.
+void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
+
} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 1497d662dcdbdd..2fe3b1302e5e5b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -32,10 +32,12 @@
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Error.h"
@@ -522,6 +524,16 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
void GpuToLLVMConversionPass::runOnOperation() {
MLIRContext *context = &getContext();
+
+ // Perform progressive lowering of vector transfer operations.
+ {
+ RewritePatternSet patterns(&getContext());
+ // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
+ vector::populateVectorTransferLoweringPatterns(patterns,
+ /*maxTransferRank=*/1);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+
LowerToLLVMOptions options(context);
options.useBarePtrCallConv = hostBarePtrCallConv;
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a9a07c323c7358..577b74bb7e0c26 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1475,16 +1475,16 @@ class VectorTypeCastOpConversion
/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
/// Non-scalable versions of this operation are handled in Vector Transforms.
-class VectorCreateMaskOpRewritePattern
- : public OpRewritePattern<vector::CreateMaskOp> {
+class VectorCreateMaskOpConversion
+ : public OpConversionPattern<vector::CreateMaskOp> {
public:
- explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
+ explicit VectorCreateMaskOpConversion(MLIRContext *context,
bool enableIndexOpt)
- : OpRewritePattern<vector::CreateMaskOp>(context),
+ : OpConversionPattern<vector::CreateMaskOp>(context),
force32BitVectorIndices(enableIndexOpt) {}
- LogicalResult matchAndRewrite(vector::CreateMaskOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto dstType = op.getType();
if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
return failure();
@@ -1495,7 +1495,7 @@ class VectorCreateMaskOpRewritePattern
loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
/*isScalable=*/true));
auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
- op.getOperand(0));
+ adaptor.getOperands()[0]);
Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
indices, bounds);
@@ -1896,16 +1896,19 @@ struct VectorScalableStepOpLowering
} // namespace
+void mlir::vector::populateVectorRankReducingFMAPattern(
+ RewritePatternSet &patterns) {
+ patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
+}
+
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions, bool force32BitVectorIndices) {
+ // This function populates only ConversionPatterns, not RewritePatterns.
MLIRContext *ctx = converter.getDialect()->getContext();
- patterns.add<VectorFMAOpNDRewritePattern>(ctx);
- populateVectorInsertExtractStridedSliceTransforms(patterns);
- populateVectorStepLoweringPatterns(patterns);
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
- patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
+ patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
@@ -1922,8 +1925,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
MaskedReductionOpConversion, VectorInterleaveOpLowering,
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
VectorScalableStepOpLowering>(converter);
- // Transfer ops with rank > 1 are handled by VectorToSCF.
- populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 64a9ad8e9bade0..2d94c2f2e85a08 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -62,7 +62,8 @@ struct ConvertVectorToLLVMPass
void ConvertVectorToLLVMPass::runOnOperation() {
// Perform progressive lowering of operations on slices and all contraction
- // operations. Also materializes masks, applies folding and DCE.
+ // operations. Also materializes masks, lowers vector.step, rank-reduces FMA,
+ // applies folding and DCE.
{
RewritePatternSet patterns(&getContext());
populateVectorToVectorCanonicalizationPatterns(patterns);
@@ -78,6 +79,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
populateVectorMaskMaterializationPatterns(patterns,
force32BitVectorIndices);
+ populateVectorInsertExtractStridedSliceTransforms(patterns);
+ populateVectorStepLoweringPatterns(patterns);
+ populateVectorRankReducingFMAPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/test/Conversion/GPUCommon/lower-vector.mlir b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
index 44deb45cd752b4..532a2383cea9ef 100644
--- a/mlir/test/Conversion/GPUCommon/lower-vector.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
@@ -1,11 +1,11 @@
// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
module {
- func.func @func(%arg: vector<11xf32>) {
+ func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
%cst_41 = arith.constant dense<true> : vector<11xi1>
// CHECK: vector.mask
// CHECK-SAME: vector.yield %arg0
%127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
- return
+ return %127 : vector<11xf32>
}
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index ea88fece9e662d..f95e943250bd44 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2046,7 +2046,6 @@ func.func @extract_strided_slice_f32_2d_from_2d_scalable(%arg0: vector<4x[8]xf32
// CHECK-LABEL: @extract_strided_slice_f32_2d_from_2d_scalable(
// CHECK-SAME: %[[ARG:.*]]: vector<4x[8]xf32>)
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
-// CHECK: %[[T2:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[T3:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
// CHECK: %[[T5:.*]] = llvm.extractvalue %[[T1]][2] : !llvm.array<4 x vector<[8]xf32>>
@@ -2067,7 +2066,6 @@ func.func @insert_strided_slice_f32_2d_into_3d(%b: vector<4x4xf32>, %c: vector<4
return %0 : vector<4x4x4xf32>
}
// CHECK-LABEL: @insert_strided_slice_f32_2d_into_3d
-// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xf32>>>
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xf32>>>
// -----
@@ -2077,7 +2075,6 @@ func.func @insert_strided_slice_f32_2d_into_3d_scalable(%b: vector<4x[4]xf32>, %
return %0 : vector<4x4x[4]xf32>
}
// CHECK-LABEL: @insert_strided_slice_f32_2d_into_3d_scalable
-// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xf32>>>
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xf32>>>
// -----
@@ -2087,7 +2084,6 @@ func.func @insert_strided_index_slice_index_2d_into_3d(%b: vector<4x4xindex>, %c
return %0 : vector<4x4x4xindex>
}
// CHECK-LABEL: @insert_strided_index_slice_index_2d_into_3d
-// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
// -----
@@ -2097,7 +2093,6 @@ func.func @insert_strided_index_slice_index_2d_into_3d_scalable(%b: vector<4x[4]
return %0 : vector<4x4x[4]xindex>
}
// CHECK-LABEL: @insert_strided_index_slice_index_2d_into_3d_scalable
-// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xi64>>>
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xi64>>>
// -----
>From ddc92940e9a1490f5a26d77576f4b663253e3cdd Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 25 Nov 2024 04:13:24 +0100
Subject: [PATCH 3/3] [mlir][Transforms] Dialect Conversion: Do not build
target mat. during 1:N replacement
fix test
experiment
---
.../Conversion/LLVMCommon/TypeConverter.cpp | 130 ++++++++++++------
.../Transforms/Utils/DialectConversion.cpp | 46 ++-----
mlir/test/Transforms/test-legalizer.mlir | 8 +-
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 47 +++----
4 files changed, 128 insertions(+), 103 deletions(-)
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 59b0f5c9b09bcd..e2ab0ed6f66cc5 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,70 +153,112 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
type.isVarArg());
});
+ // Add generic source and target materializations to handle cases where
+ // non-LLVM types persist after an LLVM conversion.
+ addSourceMaterialization([&](OpBuilder &builder, Type resultType,
+ ValueRange inputs, Location loc) {
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ .getResult(0);
+ });
+ addTargetMaterialization([&](OpBuilder &builder, Type resultType,
+ ValueRange inputs, Location loc) {
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ .getResult(0);
+ });
+
// Helper function that checks if the given value range is a bare pointer.
auto isBarePointer = [](ValueRange values) {
return values.size() == 1 &&
isa<LLVM::LLVMPointerType>(values.front().getType());
};
- // Argument materializations convert from the new block argument types
- // (multiple SSA values that make up a memref descriptor) back to the
- // original block argument type. The dialect conversion framework will then
- // insert a target materialization from the original block argument type to
- // a legal type.
- addArgumentMaterialization([&](OpBuilder &builder,
- UnrankedMemRefType resultType,
- ValueRange inputs, Location loc) {
+ // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
+ // must be passed explicitly.
+ auto packUnrankedMemRefDesc =
+ [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
+ Location loc, LLVMTypeConverter &converter) -> Value {
// Note: Bare pointers are not supported for unranked memrefs because a
// memref descriptor cannot be built just from a bare pointer.
- if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
+ if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
return Value();
- Value desc =
- UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+ return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
+ inputs);
+ };
+
+ // MemRef descriptor elements -> UnrankedMemRefType
+ auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
+ UnrankedMemRefType resultType,
+ ValueRange inputs, Location loc) {
// An argument materialization must return a value of type
// `resultType`, so insert a cast from the memref descriptor type
// (!llvm.struct) to the original memref type.
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
- .getResult(0);
- });
- addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
- ValueRange inputs, Location loc) {
- Value desc;
- if (isBarePointer(inputs)) {
- desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
- inputs[0]);
- } else if (TypeRange(inputs) ==
- getMemRefDescriptorFields(resultType,
- /*unpackAggregates=*/true)) {
- desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
- } else {
- // The inputs are neither a bare pointer nor an unpacked memref
- // descriptor. This materialization function cannot be used.
+ Value packed =
+ packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this);
+ if (!packed)
return Value();
- }
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+ .getResult(0);
+ };
+
+ // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
+ // must be passed explicitly.
+ auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
+ ValueRange inputs, Location loc,
+ LLVMTypeConverter &converter) -> Value {
+ assert(resultType && "expected non-null result type");
+ if (isBarePointer(inputs))
+ return MemRefDescriptor::fromStaticShape(builder, loc, converter,
+ resultType, inputs[0]);
+ if (TypeRange(inputs) ==
+ converter.getMemRefDescriptorFields(resultType,
+ /*unpackAggregates=*/true))
+ return MemRefDescriptor::pack(builder, loc, converter, resultType,
+ inputs);
+ // The inputs are neither a bare pointer nor an unpacked memref descriptor.
+ // This materialization function cannot be used.
+ return Value();
+ };
+
+ // MemRef descriptor elements -> MemRefType
+ auto rankedMemRefMaterialization = [&](OpBuilder &builder,
+ MemRefType resultType,
+ ValueRange inputs, Location loc) {
// An argument materialization must return a value of type `resultType`,
// so insert a cast from the memref descriptor type (!llvm.struct) to the
// original memref type.
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
- .getResult(0);
- });
- // Add generic source and target materializations to handle cases where
- // non-LLVM types persist after an LLVM conversion.
- addSourceMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs, Location loc) {
- if (inputs.size() != 1)
+ Value packed =
+ packRankedMemRefDesc(builder, resultType, inputs, loc, *this);
+ if (!packed)
return Value();
-
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
.getResult(0);
- });
+ };
+
+ // Argument materializations convert from the new block argument types
+ // (multiple SSA values that make up a memref descriptor) back to the
+ // original block argument type.
+ addArgumentMaterialization(unrakedMemRefMaterialization);
+ addArgumentMaterialization(rankedMemRefMaterialization);
+ addSourceMaterialization(unrakedMemRefMaterialization);
+ addSourceMaterialization(rankedMemRefMaterialization);
+
+ // Bare pointer -> Packed MemRef descriptor
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs, Location loc) {
- if (inputs.size() != 1)
+ ValueRange inputs, Location loc,
+ Type originalType) -> Value {
+ // The original MemRef type is required to build a MemRef descriptor
+ // because the sizes/strides of the MemRef cannot be inferred from just the
+ // bare pointer.
+ if (!originalType)
return Value();
-
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
- .getResult(0);
+ if (resultType != convertType(originalType))
+ return Value();
+ if (auto memrefType = dyn_cast<MemRefType>(originalType))
+ return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
+ if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
+ return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
+ *this);
+ return Value();
});
// Integer memory spaces map to themselves.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 1607740a1ee076..51686646a0a2fc 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -849,8 +849,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// function will be deleted when full 1:N support has been added.
///
/// This function inserts an argument materialization back to the original
- /// type, followed by a target materialization to the legalized type (if
- /// applicable).
+ /// type.
void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
ValueRange replacements, Value originalValue,
const TypeConverter *converter);
@@ -1376,9 +1375,13 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// used as a replacement.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
- insertNTo1Materialization(
- OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
- /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+ if (replArgs.size() == 1) {
+ mapping.map(origArg, replArgs.front());
+ } else {
+ insertNTo1Materialization(
+ OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
+ /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+ }
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
}
@@ -1437,36 +1440,12 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
// Insert argument materialization back to the original type.
Type originalType = originalValue.getType();
UnrealizedConversionCastOp argCastOp;
- Value argMat = buildUnresolvedMaterialization(
+ buildUnresolvedMaterialization(
MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
- /*inputs=*/replacements, originalType, /*originalType=*/Type(), converter,
- &argCastOp);
+ /*inputs=*/replacements, originalType,
+ /*originalType=*/Type(), converter, &argCastOp);
if (argCastOp)
nTo1TempMaterializations.insert(argCastOp);
-
- // Insert target materialization to the legalized type.
- Type legalOutputType;
- if (converter) {
- legalOutputType = converter->convertType(originalType);
- } else if (replacements.size() == 1) {
- // When there is no type converter, assume that the replacement value
- // types are legal. This is reasonable to assume because they were
- // specified by the user.
- // FIXME: This won't work for 1->N conversions because multiple output
- // types are not supported in parts of the dialect conversion. In such a
- // case, we currently use the original value type.
- legalOutputType = replacements[0].getType();
- }
- if (legalOutputType && legalOutputType != originalType) {
- UnrealizedConversionCastOp targetCastOp;
- buildUnresolvedMaterialization(
- MaterializationKind::Target, computeInsertPoint(argMat), loc,
- /*valueToMap=*/argMat, /*inputs=*/argMat,
- /*outputType=*/legalOutputType, /*originalType=*/originalType,
- converter, &targetCastOp);
- if (targetCastOp)
- nTo1TempMaterializations.insert(targetCastOp);
- }
}
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
@@ -2864,6 +2843,9 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
LogicalResult TypeConverter::convertType(Type t,
SmallVectorImpl<Type> &results) const {
+ assert(this && "expected non-null type converter");
+ assert(t && "expected non-null type");
+
{
std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
std::defer_lock);
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d98a6a036e6b1f..2ca5f49637523f 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -124,10 +124,10 @@ func.func @no_remap_nested() {
// CHECK-NEXT: "foo.region"
// expected-remark@+1 {{op 'foo.region' is not legalizable}}
"foo.region"() ({
- // CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
- ^bb0(%i0: i64, %unused: i16, %i1: i64):
- // CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
- "test.invalid"(%i0, %i1) : (i64, i64) -> ()
+ // CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: i16, %{{.*}}: f64):
+ ^bb0(%i0: f64, %unused: i16, %i1: f64):
+ // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
+ "test.invalid"(%i0, %i1) : (f64, f64) -> ()
}) : () -> ()
// expected-remark@+1 {{op 'func.return' is not legalizable}}
return
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 8a0bc597c56beb..466ae7ff6f46f1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -979,8 +979,8 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
};
/// This pattern simply updates the operands of the given operation.
struct TestPassthroughInvalidOp : public ConversionPattern {
- TestPassthroughInvalidOp(MLIRContext *ctx)
- : ConversionPattern("test.invalid", 1, ctx) {}
+ TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
+ : ConversionPattern(converter, "test.invalid", 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -1301,19 +1301,19 @@ struct TestLegalizePatternDriver
TestTypeConverter converter;
mlir::RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
- patterns.add<
- TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
- TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
- TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
- TestSplitReturnType, TestChangeProducerTypeI32ToF32,
- TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
- TestUpdateConsumerType, TestNonRootReplacement,
- TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
- TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
- TestUndoPropertiesModification, TestEraseOp,
- TestRepetitive1ToNConsumer>(&getContext());
- patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
- &getContext(), converter);
+ patterns
+ .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
+ TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
+ TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
+ TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
+ TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+ TestNonRootReplacement, TestBoundedRecursiveRewrite,
+ TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
+ TestCreateUnregisteredOp, TestUndoMoveOpBefore,
+ TestUndoPropertiesModification, TestEraseOp,
+ TestRepetitive1ToNConsumer>(&getContext());
+ patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
+ TestPassthroughInvalidOp>(&getContext(), converter);
patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1749,8 +1749,9 @@ struct TestTypeConversionAnotherProducer
};
struct TestReplaceWithLegalOp : public ConversionPattern {
- TestReplaceWithLegalOp(MLIRContext *ctx)
- : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
+ TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
+ : ConversionPattern(converter, "test.replace_with_legal_op",
+ /*benefit=*/1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -1872,12 +1873,12 @@ struct TestTypeConversionDriver
// Initialize the set of rewrite patterns.
RewritePatternSet patterns(&getContext());
- patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
- TestSignatureConversionUndo,
- TestTestSignatureConversionNoConverter>(converter,
- &getContext());
- patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
- &getContext());
+ patterns
+ .add<TestTypeConsumerForward, TestTypeConversionProducer,
+ TestSignatureConversionUndo,
+ TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
+ converter, &getContext());
+ patterns.add<TestTypeConversionAnotherProducer>(&getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
More information about the Mlir-commits
mailing list