[Mlir-commits] [mlir] add pattern for arith::UIToFPOp to VectorNarrowTypeRewritePatterns (PR #115485)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 11 05:14:40 PST 2024


https://github.com/ziereis updated https://github.com/llvm/llvm-project/pull/115485

>From 2e1e0f1350c6051e3eb09260546230305f8b87fc Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Fri, 8 Nov 2024 14:54:19 +0100
Subject: [PATCH 1/2] add uitofp rewrite

---
 .../Transforms/VectorEmulateNarrowType.cpp    |  6 ++--
 .../Vector/vector-rewrite-narrow-types.mlir   | 30 ++++++++++++++++++-
 2 files changed, 33 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 58841f29698e0d..76ddaa2df5a9d9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1452,8 +1452,10 @@ void vector::populateVectorNarrowTypeRewritePatterns(
                RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
                RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                               benefit.getBenefit() + 1);
-  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>>(
-      patterns.getContext(), benefit.getBenefit() + 1);
+  patterns
+      .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
+           RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
+          patterns.getContext(), benefit.getBenefit() + 1);
 }
 
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 84aaa9c61200b9..75e46a79600d08 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -262,6 +262,34 @@ func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
   return %0 : vector<8x32xf32>
 }
 
+// CHECK-LABEL: func.func @aligned_uitofp(
+func.func @aligned_uitofp(%a: vector<8xi4>) -> vector<8xf32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK:           %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
+  %0 = arith.uitofp %a : vector<8xi4> to vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func.func @aligned_uitofp_2d(
+func.func @aligned_uitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK:           %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
+  %0 = arith.uitofp %a : vector<8x32xi4> to vector<8x32xf32>
+  return %0 : vector<8x32xf32>
+}
+
 // CHECK-LABEL: func.func @aligned_trunci(
 func.func @aligned_trunci(%a: vector<8xi32>) -> vector<8xi4> {
 // CHECK-SAME:    %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> {
@@ -314,7 +342,7 @@ func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
   // CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[I4_MASK]] : vector<3x8x16xi8>
   // CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[LEFT_SHIFT_BITS]] : vector<3x8x16xi8>
   // CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<3x8x16xi8>
-  // CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<3x8x16xi8> to vector<3x8x32xi4> 
+  // CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<3x8x16xi8> to vector<3x8x32xi4>
   %0 = arith.trunci %a : vector<3x8x32xi32> to vector<3x8x32xi4>
   return %0 : vector<3x8x32xi4>
 }

>From ecbad796fc678bbdd269164fa9cf423a6dd2b88f Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Mon, 11 Nov 2024 14:14:26 +0100
Subject: [PATCH 2/2] renaming tests

---
 .../Vector/vector-rewrite-narrow-types.mlir   | 212 +++++++++---------
 1 file changed, 107 insertions(+), 105 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 75e46a79600d08..210025e30d7db5 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -193,36 +193,8 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
   return %1 : vector<8xi17>
 }
 
-// CHECK-LABEL: func.func @aligned_extsi(
-func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
-// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
-// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
-// CHECK:           %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
-  %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
-  return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_extsi_2d(
-func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
-// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
-// CHECK:           %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
-  %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
-  return %0 : vector<8x32xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_extsi_base_case(
-func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
+// CHECK-LABEL: func.func @aligned_extsi_i4_to_i8(
+func.func @aligned_extsi_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
 // CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
 // CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
 // CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
@@ -234,88 +206,61 @@ func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
   return %0 : vector<8xi8>
 }
 
-// CHECK-LABEL: func.func @aligned_sitofp(
-func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
+// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32(
+func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
 // CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
 // CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
 // CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
 // CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
 // CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
 // CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
-// CHECK:           %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
-  %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
-  return %0 : vector<8xf32>
+// CHECK:           %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+  %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
+  return %0 : vector<8xi32>
 }
 
-// CHECK-LABEL: func.func @aligned_sitofp_2d(
-func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK-LABEL: func.func @aligned_extsi_2d(
+func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
 // CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
 // CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
 // CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
 // CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
 // CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
 // CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
-// CHECK:           %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
-  %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32>
-  return %0 : vector<8x32xf32>
-}
-
-// CHECK-LABEL: func.func @aligned_uitofp(
-func.func @aligned_uitofp(%a: vector<8xi4>) -> vector<8xf32> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
-// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
-// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
-// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
-// CHECK:           %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
-  %0 = arith.uitofp %a : vector<8xi4> to vector<8xf32>
-  return %0 : vector<8xf32>
+// CHECK:           %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+  %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
 }
 
-// CHECK-LABEL: func.func @aligned_uitofp_2d(
-func.func @aligned_uitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
-// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
-// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
-// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
-// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
-// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
-// CHECK:           %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
-  %0 = arith.uitofp %a : vector<8x32xi4> to vector<8x32xf32>
-  return %0 : vector<8x32xf32>
-}
 
-// CHECK-LABEL: func.func @aligned_trunci(
-func.func @aligned_trunci(%a: vector<8xi32>) -> vector<8xi4> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> {
+// CHECK-LABEL: func.func @aligned_trunci_i8_to_i4(
+func.func @aligned_trunci_i8_to_i4(%a: vector<8xi8>) -> vector<8xi4> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
 // CHECK-DAG:       %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
 // CHECK-DAG:       %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK:           %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8>
-// CHECK:           %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<8xi8> -> vector<4xi8>
+// CHECK:           %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[IN]] : vector<8xi8> -> vector<4xi8>
 // CHECK:           %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
 // CHECK:           %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
 // CHECK:           %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
 // CHECK:           %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
-  %0 = arith.trunci %a : vector<8xi32> to vector<8xi4>
+  %0 = arith.trunci %a : vector<8xi8> to vector<8xi4>
   return %0 : vector<8xi4>
 }
 
-// CHECK-LABEL: func.func @aligned_trunci_base_case(
-func.func @aligned_trunci_base_case(%a: vector<8xi8>) -> vector<8xi4> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
+// CHECK-LABEL: func.func @aligned_trunci_i32_to_i4(
+func.func @aligned_trunci_i32_to_i4(%a: vector<8xi32>) -> vector<8xi4> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> {
 // CHECK-DAG:       %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
 // CHECK-DAG:       %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK:           %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[IN]] : vector<8xi8> -> vector<4xi8>
+// CHECK:           %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8>
+// CHECK:           %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<8xi8> -> vector<4xi8>
 // CHECK:           %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
 // CHECK:           %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
 // CHECK:           %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
 // CHECK:           %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
-  %0 = arith.trunci %a : vector<8xi8> to vector<8xi4>
+  %0 = arith.trunci %a : vector<8xi32> to vector<8xi4>
   return %0 : vector<8xi4>
 }
 
@@ -347,28 +292,21 @@ func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
   return %0 : vector<3x8x32xi4>
 }
 
-// CHECK-LABEL: func.func @i4_transpose(
-func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
-// CHECK:           %[[EXT:.*]] = vector.interleave
-// CHECK:           %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
-// CHECK:           vector.deinterleave %[[TRANS]] : vector<16x8xi8> -> vector<16x4xi8>
-  %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
-  return %0 : vector<16x8xi4>
-}
-
-// CHECK-LABEL: func.func @i7_transpose(
-func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
-// CHECK-SAME:    %[[IN:.*]]: vector<8x16xi7>) -> vector<16x8xi7> {
-// CHECK:           %[[EXT:.*]] = arith.extsi %[[IN]] : vector<8x16xi7> to vector<8x16xi8>
-// CHECK:           %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
-// CHECK:           %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
-  %0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7>
-  return %0 : vector<16x8xi7>
+// CHECK-LABEL: func.func @aligned_extui_i4_to_i8(
+func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
+// CHECK-SAME:                             %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+  %0 = arith.extui %a : vector<8xi4> to vector<8xi8>
+  return %0 : vector<8xi8>
 }
 
-// CHECK-LABEL: func.func @aligned_extui(
-func.func @aligned_extui(%a: vector<8xi4>) -> vector<8xi32> {
+// CHECK-LABEL: func.func @aligned_extui_i4_to_i32(
+func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
 // CHECK-SAME:                             %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
 // CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
 // CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
@@ -395,19 +333,83 @@ func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
   return %0 : vector<8x32xi32>
 }
 
-// CHECK-LABEL: func.func @aligned_extui_base_case(
-func.func @aligned_extui_base_case(%a: vector<8xi4>) -> vector<8xi8> {
-// CHECK-SAME:                             %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK-LABEL: func.func @aligned_sitofp(
+func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK:           %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
+  %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func.func @aligned_sitofp_2d(
+func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK:           %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
+  %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32>
+  return %0 : vector<8x32xf32>
+}
+
+// CHECK-LABEL: func.func @aligned_uitofp(
+func.func @aligned_uitofp(%a: vector<8xi4>) -> vector<8xf32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
 // CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
 // CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
 // CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
 // CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
 // CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
 // CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
-  %0 = arith.extui %a : vector<8xi4> to vector<8xi8>
-  return %0 : vector<8xi8>
+// CHECK:           %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
+  %0 = arith.uitofp %a : vector<8xi4> to vector<8xf32>
+  return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL: func.func @aligned_uitofp_2d(
+func.func @aligned_uitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK:           %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
+  %0 = arith.uitofp %a : vector<8x32xi4> to vector<8x32xf32>
+  return %0 : vector<8x32xf32>
+}
+
+// CHECK-LABEL: func.func @i4_transpose(
+func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
+// CHECK:           %[[EXT:.*]] = vector.interleave
+// CHECK:           %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+// CHECK:           vector.deinterleave %[[TRANS]] : vector<16x8xi8> -> vector<16x4xi8>
+  %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
+  return %0 : vector<16x8xi4>
+}
+
+// CHECK-LABEL: func.func @i7_transpose(
+func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8x16xi7>) -> vector<16x8xi7> {
+// CHECK:           %[[EXT:.*]] = arith.extsi %[[IN]] : vector<8x16xi7> to vector<8x16xi8>
+// CHECK:           %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+// CHECK:           %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
+  %0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7>
+  return %0 : vector<16x8xi7>
+}
+
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op



More information about the Mlir-commits mailing list