[Mlir-commits] [mlir] [mlir][Vector] Fix integration test for vector.maskedload narrow type… (PR #71346)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 8 08:13:30 PST 2023


https://github.com/tyb0807 updated https://github.com/llvm/llvm-project/pull/71346

>From d957af2616768e1f72ce39e6ab5a3d50daa36da2 Mon Sep 17 00:00:00 2001
From: tyb0807 <vuson at google.com>
Date: Mon, 6 Nov 2023 02:09:21 +0000
Subject: [PATCH] [mlir][Vector] Fix integration test for vector.maskedload
 narrow type emulation pattern

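Currently the integration test for vector.maskedload does not exercise the
narrow type emulation pattern. Move the test to a dedicated file and add
transform ops (a narrow type emulation type converter and a pattern
descriptor) so the emulation patterns can be applied through the transform
dialect.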
---
 .../Vector/TransformOps/VectorTransformOps.td | 31 ++++++++++
 .../TransformOps/VectorTransformOps.cpp       | 44 ++++++++++++++-
 .../Transforms/VectorEmulateNarrowType.cpp    |  2 -
 .../Vector/CPU/test-emulate-narrow-types.mlir | 56 +++++++++++++++++++
 .../Vector/CPU/test-rewrite-narrow-types.mlir | 23 --------
 5 files changed, 130 insertions(+), 26 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/test-emulate-narrow-types.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 3ac6f28dcb93859..e3abffb972a2b41 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -15,6 +15,30 @@ include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def NarrowTypeEmulationConverterOp : Op<Transform_Dialect,
+    "apply_conversion_patterns.vector.emulate_narrow_type_converter",
+    [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
+                               ["getTypeConverterType"]>]> {
+  let description = [{
+    This operation provides a type converter that converts narrow integer
+    types that are not supported by the target hardware to wider types.
+
+    The type converter can be customized as follows:
+    - `load_store_emulate_bitwidth`: Bitwidth of target load/store emulation
+      (default: 8).
+    - `arith_compute_bitwidth`: Bitwidth of the arith computation. Scalar and
+      vector integer types narrower than this are widened to it (default: 4).
+
+    The converter is an `arith::NarrowTypeEmulationConverter`, the kind
+    expected by `apply_conversion_patterns.vector.emulate_narrow_types`.
+  }];
+
+  let arguments = (ins
+      DefaultValuedAttr<I64Attr, "8">:$load_store_emulate_bitwidth,
+      DefaultValuedAttr<I64Attr, "4">:$arith_compute_bitwidth);
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyVectorToLLVMConversionPatternsOp : Op<Transform_Dialect,
     "apply_conversion_patterns.vector.vector_to_llvm",
     [DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
@@ -292,6 +316,24 @@ def ApplyLowerTransposePatternsOp : Op<Transform_Dialect,
   }];
 }
 
+def ApplyEmulateNarrowTypesPatternsOp : Op<Transform_Dialect,
+    "apply_conversion_patterns.vector.emulate_narrow_types",
+    [DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
+                               ["verifyTypeConverter"]>]> {
+  let description = [{
+    Indicates that vector narrow type emulations should be applied.
+
+    Requires the type converter provided by
+    `apply_conversion_patterns.vector.emulate_narrow_type_converter`; any
+    other type converter is rejected by the verifier.
+
+    This is usually a late step that is run after bufferization as part of the
+    process of lowering to e.g. LLVM or NVVM.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.rewrite_narrow_types",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 37127ea70f1e5af..a17bb2e1685aa34 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -11,6 +11,8 @@
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
@@ -45,6 +47,45 @@ transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter(
   return success();
 }
 
+/// Builds the type converter consumed by the vector narrow type emulation
+/// patterns: an `arith::NarrowTypeEmulationConverter` configured with
+/// `load_store_emulate_bitwidth`, extended so that integer scalars and vectors
+/// of integers narrower than `arith_compute_bitwidth` are widened to that
+/// bitwidth. Wider integer types and non-integer vectors are left unchanged by
+/// the conversions registered here.
+std::unique_ptr<TypeConverter>
+transform::NarrowTypeEmulationConverterOp::getTypeConverter() {
+  auto typeConverter = std::make_unique<arith::NarrowTypeEmulationConverter>(
+      getLoadStoreEmulateBitwidth());
+  int64_t arithComputeBitwidth = getArithComputeBitwidth();
+  // Convert scalar type.
+  typeConverter->addConversion(
+      [arithComputeBitwidth](IntegerType ty) -> std::optional<Type> {
+        if (ty.getWidth() >= arithComputeBitwidth)
+          return ty;
+        return IntegerType::get(ty.getContext(), arithComputeBitwidth);
+      });
+
+  // Convert vector type. Only the element type changes: the shape and the
+  // scalable dimensions are kept, so scalable vectors stay scalable.
+  typeConverter->addConversion(
+      [arithComputeBitwidth](VectorType ty) -> std::optional<Type> {
+        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
+        if (!intTy || intTy.getWidth() >= arithComputeBitwidth)
+          return ty;
+        return VectorType::get(
+            ty.getShape(),
+            IntegerType::get(ty.getContext(), arithComputeBitwidth),
+            ty.getScalableDims());
+      });
+  return typeConverter;
+}
+
+/// Names the converter kind so that pattern descriptors can verify it.
+StringRef transform::NarrowTypeEmulationConverterOp::getTypeConverterType() {
+  return "arith::NarrowTypeEmulationConverter";
+}
+
 //===----------------------------------------------------------------------===//
 // Apply...PatternsOp
 //===----------------------------------------------------------------------===//
@@ -159,9 +200,28 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
   }
 }
 
+LogicalResult
+transform::ApplyEmulateNarrowTypesPatternsOp::verifyTypeConverter(
+    transform::TypeConverterBuilderOpInterface builder) {
+  // `populatePatterns` downcasts the type converter without a runtime check,
+  // so reject every converter kind other than NarrowTypeEmulationConverter.
+  if (builder.getTypeConverterType() != "arith::NarrowTypeEmulationConverter")
+    return emitOpError("expected arith::NarrowTypeEmulationConverter");
+  return success();
+}
+
+void transform::ApplyEmulateNarrowTypesPatternsOp::populatePatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns) {
+  // The unchecked downcast relies on `verifyTypeConverter` having rejected
+  // every other converter kind.
+  vector::populateVectorNarrowTypeEmulationPatterns(
+      static_cast<arith::NarrowTypeEmulationConverter &>(typeConverter),
+      patterns);
+}
+
 void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  populateVectorNarrowTypeRewritePatterns(patterns);
+  vector::populateVectorNarrowTypeRewritePatterns(patterns);
 }
 
 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 6aea0343bfc9327..c561374b67d280f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -180,7 +180,6 @@ struct ConvertVectorMaskedLoad final
   LogicalResult
   matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
@@ -818,7 +817,6 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 void vector::populateVectorNarrowTypeEmulationPatterns(
     arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns) {
-
   // Populate `vector.*` conversion patterns.
   patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
                ConvertVectorTransferRead>(typeConverter, patterns.getContext());
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-emulate-narrow-types.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-emulate-narrow-types.mlir
new file mode 100644
index 000000000000000..c70a2da7ab0f85a
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-emulate-narrow-types.mlir
@@ -0,0 +1,63 @@
+/// Run once without applying the pattern and check the source of truth.
+// RUN: mlir-opt %s -test-transform-dialect-erase-schedule -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+/// Run once with the pattern and compare.
+// RUN: mlir-opt %s -transform-interpreter -test-transform-dialect-erase-schedule -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+/// Masked load of six i4 elements: only the first three lanes are enabled, the
+/// other lanes are taken from %passthru.
+func.func @fcst_maskedload(%A: memref<?xi4>, %passthru: vector<6xi4>) -> vector<6xi4> {
+  %c0 = arith.constant 0: index
+  %mask = vector.constant_mask [3] : vector<6xi1>
+  %1 = vector.maskedload %A[%c0], %mask, %passthru :
+    memref<?xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
+  return %1 : vector<6xi4>
+}
+
+func.func @entry() {
+  // Set up memory: A[i] = i for i in [0, 6).
+  %c0 = arith.constant 0: index
+  %c1 = arith.constant 1: index
+  %c6 = arith.constant 6: index
+  %A = memref.alloc(%c6) : memref<?xi4>
+  scf.for %i = %c0 to %c6 step %c1 {
+    %i4 = arith.index_cast %i : index to i4
+    memref.store %i4, %A[%i] : memref<?xi4>
+  }
+  %passthru = arith.constant dense<[7, 8, 9, 10, 11, 12]> : vector<6xi4>
+  %load = call @fcst_maskedload(%A, %passthru) : (memref<?xi4>, vector<6xi4>) -> (vector<6xi4>)
+  // Lanes 0-2 come from %A and lanes 3-5 from %passthru. i4 values are printed
+  // as signed, so 10, 11 and 12 show up as -6, -5 and -4.
+  vector.print %load : vector<6xi4>
+  // CHECK: ( 0, 1, 2, -6, -5, -4 )
+  memref.dealloc %A : memref<?xi4>
+
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %module_op
+        : (!transform.any_op) -> !transform.any_op
+
+    // `load_store_emulate_bitwidth = 8` selects 8-bit load/store emulation.
+    // With `arith_compute_bitwidth = 1` the scalar/vector integer conversions
+    // keep i4 as is, since every integer width is >= 1.
+    transform.apply_conversion_patterns to %f {
+      transform.apply_conversion_patterns.vector.emulate_narrow_types
+    } with type_converter {
+      transform.apply_conversion_patterns.vector.emulate_narrow_type_converter
+      {arith_compute_bitwidth = 1,
+       load_store_emulate_bitwidth = 8}
+    } {
+      partial_conversion
+    }: !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
index a0b39a2b68f4388..711079518cea056 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
@@ -164,14 +164,6 @@ func.func @fext(%a: vector<5xi8>) {
   return
 }
 
-func.func @fcst_maskedload(%A: memref<?xi4>, %passthru: vector<6xi4>) -> vector<6xi4> {
-  %c0 = arith.constant 0: index
-  %mask = vector.constant_mask [3] : vector<6xi1>
-  %1 = vector.maskedload %A[%c0], %mask, %passthru :
-    memref<?xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
-  return %1 : vector<6xi4>
-}
-
 func.func @entry() {
   %v = arith.constant dense<[
     0xffff, 0xfffe, 0xfffd, 0xfffc, 0xfffb, 0xfffa, 0xfff9, 0xfff8,
@@ -194,21 +186,6 @@ func.func @entry() {
   ]> : vector<5xi8>
   func.call @fext(%v4) : (vector<5xi8>) -> ()
 
-  // Set up memory.
-  %c0 = arith.constant 0: index
-  %c1 = arith.constant 1: index
-  %c6 = arith.constant 6: index
-  %A = memref.alloc(%c6) : memref<?xi4>
-  scf.for %i = %c0 to %c6 step %c1 {
-    %i4 = arith.index_cast %i : index to i4
-    memref.store %i4, %A[%i] : memref<?xi4>
-  }
-  %passthru = arith.constant dense<[7, 8, 9, 10, 11, 12]> : vector<6xi4>
-  %load = call @fcst_maskedload(%A, %passthru) : (memref<?xi4>, vector<6xi4>) -> (vector<6xi4>)
-  vector.print %load : vector<6xi4>
-  // CHECK: ( 0, 1, 2, -6, -5, -4 )
-  memref.dealloc %A : memref<?xi4>
-
   return
 }
 



More information about the Mlir-commits mailing list