[Mlir-commits] [mlir] add pattern for arith::UIToFPOp to VectorNarrowTypeRewritePatterns (PR #115485)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 8 06:08:10 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (ziereis)

<details>
<summary>Changes</summary>

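This PR extends the patterns from https://github.com/llvm/llvm-project/pull/89131 to `arith::UIToFPOp` by registering `RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>` in `populateVectorNarrowTypeRewritePatterns`.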

---
Full diff: https://github.com/llvm/llvm-project/pull/115485.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+4-2) 
- (modified) mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir (+29-1) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 58841f29698e0d..76ddaa2df5a9d9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1452,8 +1452,10 @@ void vector::populateVectorNarrowTypeRewritePatterns(
                RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
-  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>>(
-      patterns.getContext(), benefit.getBenefit() + 1);
+  patterns
+      .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
+           RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
+          patterns.getContext(), benefit.getBenefit() + 1);
 }
 
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 84aaa9c61200b9..75e46a79600d08 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -262,6 +262,34 @@ func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
   return %0 : vector<8x32xf32>
 }
 
+// CHECK-LABEL: func.func @aligned_uitofp(
+func.func @aligned_uitofp(%a: vector<8xi4>) -> vector<8xf32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK:           %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
+  %0 = arith.uitofp %a : vector<8xi4> to vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func.func @aligned_uitofp_2d(
+func.func @aligned_uitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK:           %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
+  %0 = arith.uitofp %a : vector<8x32xi4> to vector<8x32xf32>
+  return %0 : vector<8x32xf32>
+}
+
 // CHECK-LABEL: func.func @aligned_trunci(
 func.func @aligned_trunci(%a: vector<8xi32>) -> vector<8xi4> {
 // CHECK-SAME:    %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> {
@@ -314,7 +342,7 @@ func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
   // CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[I4_MASK]] : vector<3x8x16xi8>
   // CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[LEFT_SHIFT_BITS]] : vector<3x8x16xi8>
   // CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<3x8x16xi8>
-  // CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<3x8x16xi8> to vector<3x8x32xi4> 
+  // CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<3x8x16xi8> to vector<3x8x32xi4>
   %0 = arith.trunci %a : vector<3x8x32xi32> to vector<3x8x32xi4>
   return %0 : vector<3x8x32xi4>
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/115485


More information about the Mlir-commits mailing list