[Mlir-commits] [mlir] ff96267 - [mlir][vector] Add folder for bitcast of integer splat constants

Quinn Dawkins llvmlistbot at llvm.org
Mon Jan 30 20:53:01 PST 2023


Author: Quinn Dawkins
Date: 2023-01-30T23:40:42-05:00
New Revision: ff96267b42021e3f0d886579e5405033a88b7222

URL: https://github.com/llvm/llvm-project/commit/ff96267b42021e3f0d886579e5405033a88b7222
DIFF: https://github.com/llvm/llvm-project/commit/ff96267b42021e3f0d886579e5405033a88b7222.diff

LOG: [mlir][vector] Add folder for bitcast of integer splat constants

This is similar to the existing folder for f16 to f32 bitcasts added in
D96041, but for integer types where the destination bit width is larger than
(and a multiple of) the source bit width.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D142922

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index cefd629127842..32ae7b1017e87 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4950,6 +4950,27 @@ OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
     }
   }
 
+  if (auto intPack = sourceConstant.dyn_cast<DenseIntElementsAttr>()) {
+    if (intPack.isSplat()) {
+      auto splat = intPack.getSplatValue<IntegerAttr>();
+
+      if (dstElemType.isa<IntegerType>()) {
+        uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
+        uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
+
+        // Casting to a larger integer bit width.
+        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
+          APInt intBits = splat.getValue().zext(dstBitWidth);
+
+          // Duplicate the lower width element.
+          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
+            intBits = (intBits << srcBitWidth) | intBits;
+          return DenseElementsAttr::get(getResultVectorType(), intBits);
+        }
+      }
+    }
+  }
+
   return {};
 }
 

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index fb890b6a0f44a..8fc1834ec6aaa 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -741,6 +741,20 @@ func.func @bitcast_f16_to_f32() -> (vector<4xf32>, vector<4xf32>) {
   return %cast0, %cast1: vector<4xf32>, vector<4xf32>
 }
 
+// CHECK-LABEL: func @bitcast_i8_to_i32
+//              bit pattern: 0xA0A0A0A0
+//       CHECK-DAG: %[[CST1:.+]] = arith.constant dense<-1600085856> : vector<4xi32>
+//              bit pattern: 0x00000000
+//       CHECK-DAG: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi32>
+//       CHECK: return %[[CST0]], %[[CST1]]
+func.func @bitcast_i8_to_i32() -> (vector<4xi32>, vector<4xi32>) {
+  %cst0 = arith.constant dense<0> : vector<16xi8> // bit pattern: 0x00
+  %cst1 = arith.constant dense<160> : vector<16xi8> // bit pattern: 0xA0
+  %cast0 = vector.bitcast %cst0: vector<16xi8> to vector<4xi32>
+  %cast1 = vector.bitcast %cst1: vector<16xi8> to vector<4xi32>
+  return %cast0, %cast1: vector<4xi32>, vector<4xi32>
+}
+
 // -----
 
 // CHECK-LABEL: broadcast_folding1


        


More information about the Mlir-commits mailing list