[Mlir-commits] [mlir] [mlir][Vector] Fix integration test for vector.maskedload narrow type… (PR #71346)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Nov 5 18:14:46 PST 2023
https://github.com/tyb0807 created https://github.com/llvm/llvm-project/pull/71346
… emulation pattern
Currently the test does not exercise the emulation pattern.
>From 460c14b8da6a5e1306180f0cb9e1d7daf8bb6126 Mon Sep 17 00:00:00 2001
From: tyb0807 <vuson at google.com>
Date: Mon, 6 Nov 2023 02:09:21 +0000
Subject: [PATCH] [mlir][Vector] Fix integration test for vector.maskedload
narrow type emulation pattern
Currently the test does not exercise the emulation pattern.
---
.../Vector/TransformOps/VectorTransformOps.td | 31 ++++++++++
.../TransformOps/VectorTransformOps.cpp | 40 ++++++++++++-
.../Transforms/VectorEmulateNarrowType.cpp | 4 +-
.../Vector/CPU/test-emulate-narrow-types.mlir | 57 +++++++++++++++++++
.../Vector/CPU/test-rewrite-narrow-types.mlir | 23 --------
5 files changed, 129 insertions(+), 26 deletions(-)
create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/test-emulate-narrow-types.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 3ac6f28dcb93859..e3abffb972a2b41 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -15,6 +15,24 @@ include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
+def NarrowTypeEmulationConverterOp : Op<Transform_Dialect,
+ "apply_conversion_patterns.vector.emulate_narrow_type_converter",
+ [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface>]> {
+ let description = [{
+ This operation provides a type converter that converts narrow integer or
+ float types that are not supported by the target hardware to wider types.
+
+ The type converter can be customized as follows:
+ - `load_store_emulate_bitwidth`: Bitwidth of target load/store emulation.
+ - `arith_compute_bitwidth`: Bitwidth of the arith computation.
+ }];
+
+ let arguments = (ins
+ DefaultValuedAttr<I64Attr, "8">:$load_store_emulate_bitwidth,
+ DefaultValuedAttr<I64Attr, "4">:$arith_compute_bitwidth);
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyVectorToLLVMConversionPatternsOp : Op<Transform_Dialect,
"apply_conversion_patterns.vector.vector_to_llvm",
[DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
@@ -292,6 +310,19 @@ def ApplyLowerTransposePatternsOp : Op<Transform_Dialect,
}];
}
+def ApplyEmulateNarrowTypesPatternsOp : Op<Transform_Dialect,
+ "apply_conversion_patterns.vector.emulate_narrow_types",
+ [DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector narrow type emulations should be applied.
+
+ This is usually a late step that is run after bufferization as part of the
+ process of lowering to e.g. LLVM or NVVM.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.rewrite_narrow_types",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 37127ea70f1e5af..30f68fefe7287c1 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -11,6 +11,8 @@
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
@@ -45,6 +47,36 @@ transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter(
return success();
}
+std::unique_ptr<TypeConverter>
+transform::NarrowTypeEmulationConverterOp::getTypeConverter() {
+ auto typeConverter = std::make_unique<arith::NarrowTypeEmulationConverter>(getLoadStoreEmulateBitwidth());
+ // Convert scalar type.
+ typeConverter->addConversion([this](IntegerType ty) -> std::optional<Type> {
+ unsigned width = ty.getWidth();
+ if (width >= getArithComputeBitwidth())
+ return ty;
+
+ return IntegerType::get(ty.getContext(), getArithComputeBitwidth());
+ });
+
+ // Convert vector type.
+ typeConverter->addConversion([this](VectorType ty) -> std::optional<Type> {
+ auto intTy = dyn_cast<IntegerType>(ty.getElementType());
+ if (!intTy)
+ return ty;
+
+ unsigned width = intTy.getWidth();
+ if (width >= getArithComputeBitwidth())
+ return ty;
+
+ return VectorType::get(
+ to_vector(ty.getShape()),
+ IntegerType::get(ty.getContext(), getArithComputeBitwidth()));
+ });
+ llvm::errs() << "********************* " << getLoadStoreEmulateBitwidth() << ' ' << getArithComputeBitwidth() << " OK TypeConverter\n";
+ return typeConverter;
+}
+
//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//
@@ -159,9 +191,15 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
}
}
+void transform::ApplyEmulateNarrowTypesPatternsOp::populatePatterns(
+ TypeConverter &typeConverter, RewritePatternSet &patterns) {
+ vector::populateVectorNarrowTypeEmulationPatterns(static_cast<arith::NarrowTypeEmulationConverter &>(typeConverter), patterns);
+ llvm::errs() << "********************* OK TD\n";
+}
+
void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- populateVectorNarrowTypeRewritePatterns(patterns);
+ vector::populateVectorNarrowTypeRewritePatterns(patterns);
}
void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 3d65123373109b3..976730ca0ead4f5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -115,7 +115,7 @@ struct ConvertVectorMaskedLoad final
LogicalResult
matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-
+ llvm::errs() << "***************************** OK2\n";
auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
@@ -753,10 +753,10 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
void vector::populateVectorNarrowTypeEmulationPatterns(
arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns) {
-
// Populate `vector.*` conversion patterns.
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
ConvertVectorTransferRead>(typeConverter, patterns.getContext());
+ llvm::errs() << "********************* OK\n";
}
void vector::populateVectorNarrowTypeRewritePatterns(
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-emulate-narrow-types.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-emulate-narrow-types.mlir
new file mode 100644
index 000000000000000..53ed88dfa52b187
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-emulate-narrow-types.mlir
@@ -0,0 +1,57 @@
+/// Run once without applying the pattern and check the source of truth.
+// RUN: mlir-opt %s --test-transform-dialect-erase-schedule -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+/// Run once with the pattern and compare.
+// RUN: mlir-opt %s -transform-interpreter -test-transform-dialect-erase-schedule -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @fcst_maskedload(%A: memref<?xi4>, %passthru: vector<6xi4>) -> vector<6xi4> {
+ %c0 = arith.constant 0: index
+ %mask = vector.constant_mask [3] : vector<6xi1>
+ %1 = vector.maskedload %A[%c0], %mask, %passthru :
+ memref<?xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
+ return %1 : vector<6xi4>
+}
+
+func.func @entry() {
+ // Set up memory.
+ %c0 = arith.constant 0: index
+ %c1 = arith.constant 1: index
+ %c6 = arith.constant 6: index
+ %A = memref.alloc(%c6) : memref<?xi4>
+ scf.for %i = %c0 to %c6 step %c1 {
+ %i4 = arith.index_cast %i : index to i4
+ memref.store %i4, %A[%i] : memref<?xi4>
+ }
+ %passthru = arith.constant dense<[7, 8, 9, 10, 11, 12]> : vector<6xi4>
+ %load = call @fcst_maskedload(%A, %passthru) : (memref<?xi4>, vector<6xi4>) -> (vector<6xi4>)
+ vector.print %load : vector<6xi4>
+ // CHECK: ( 0, 1, 2, -6, -5, -4 )
+ memref.dealloc %A : memref<?xi4>
+
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_conversion_patterns to %f {
+ transform.apply_conversion_patterns.vector.emulate_narrow_types
+ } with type_converter {
+ transform.apply_conversion_patterns.vector.emulate_narrow_type_converter
+ {arith_compute_bitwidth = 1,
+ load_store_emulate_bitwidth = 8}
+ } {
+ legal_dialects = ["vector"],
+ partial_conversion
+ }: !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
index a0b39a2b68f4388..711079518cea056 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
@@ -164,14 +164,6 @@ func.func @fext(%a: vector<5xi8>) {
return
}
-func.func @fcst_maskedload(%A: memref<?xi4>, %passthru: vector<6xi4>) -> vector<6xi4> {
- %c0 = arith.constant 0: index
- %mask = vector.constant_mask [3] : vector<6xi1>
- %1 = vector.maskedload %A[%c0], %mask, %passthru :
- memref<?xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
- return %1 : vector<6xi4>
-}
-
func.func @entry() {
%v = arith.constant dense<[
0xffff, 0xfffe, 0xfffd, 0xfffc, 0xfffb, 0xfffa, 0xfff9, 0xfff8,
@@ -194,21 +186,6 @@ func.func @entry() {
]> : vector<5xi8>
func.call @fext(%v4) : (vector<5xi8>) -> ()
- // Set up memory.
- %c0 = arith.constant 0: index
- %c1 = arith.constant 1: index
- %c6 = arith.constant 6: index
- %A = memref.alloc(%c6) : memref<?xi4>
- scf.for %i = %c0 to %c6 step %c1 {
- %i4 = arith.index_cast %i : index to i4
- memref.store %i4, %A[%i] : memref<?xi4>
- }
- %passthru = arith.constant dense<[7, 8, 9, 10, 11, 12]> : vector<6xi4>
- %load = call @fcst_maskedload(%A, %passthru) : (memref<?xi4>, vector<6xi4>) -> (vector<6xi4>)
- vector.print %load : vector<6xi4>
- // CHECK: ( 0, 1, 2, -6, -5, -4 )
- memref.dealloc %A : memref<?xi4>
-
return
}
More information about the Mlir-commits
mailing list