[Mlir-commits] [mlir] f1162fb - [mlir][spirv] Add folds for `BitwiseAnd` and `BitwiseOr`

Jakub Kuderski llvmlistbot at llvm.org
Mon Sep 4 10:13:00 PDT 2023


Author: Jakub Kuderski
Date: 2023-09-04T13:12:40-04:00
New Revision: f1162fb677bf6641928a2390edac5a8d62cf1250

URL: https://github.com/llvm/llvm-project/commit/f1162fb677bf6641928a2390edac5a8d62cf1250
DIFF: https://github.com/llvm/llvm-project/commit/f1162fb677bf6641928a2390edac5a8d62cf1250.diff

LOG: [mlir][spirv] Add folds for `BitwiseAnd` and `BitwiseOr`

The newly folded patterns get produced by type emulation code, so we
cannot rely on the Arith folders handling these folds.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D159394

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/IR/bit-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index f20f113a7331f74..286f4de6f90f621 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -338,6 +338,8 @@ def SPIRV_BitwiseAndOp : SPIRV_BitBinaryOp<"BitwiseAnd",
     %2 = spirv.BitwiseAnd %0, %1 : vector<4xi32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -373,6 +375,8 @@ def SPIRV_BitwiseOrOp : SPIRV_BitBinaryOp<"BitwiseOr",
     %2 = spirv.BitwiseOr %0, %1 : vector<4xi32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index fb522f942e2b8ad..f7ab3c0702a98b0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -35,6 +35,7 @@
 #include "llvm/ADT/StringExtras.h"
 #include <cassert>
 #include <numeric>
+#include <optional>
 #include <type_traits>
 
 using namespace mlir;
@@ -1961,6 +1962,72 @@ LogicalResult spirv::ShiftRightLogicalOp::verify() {
   return verifyShiftOp(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseAndOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the integer carried by `attr`, which may be either a scalar
+/// `IntegerAttr` or a `SplatElementsAttr` whose splat value is an
+/// `IntegerAttr`. A null attribute, or any other kind of attribute (for
+/// example a float or a non-splat dense constant), yields std::nullopt.
+static std::optional<APInt> extractIntConstant(Attribute attr) {
+  // A splat is never itself an IntegerAttr, so trying the scalar form first
+  // and then overriding it for splats is equivalent to an if/else.
+  auto scalar = dyn_cast_if_present<IntegerAttr>(attr);
+  if (auto splatAttr = dyn_cast_if_present<SplatElementsAttr>(attr))
+    scalar = dyn_cast<IntegerAttr>(splatAttr.getSplatValue<Attribute>());
+
+  if (scalar)
+    return scalar.getValue();
+  return std::nullopt;
+}
+
+/// Folds `spirv.BitwiseAnd` when its result is trivially known:
+///   x & x                                   -> x
+///   x & 0                                   -> 0
+///   x & <all ones>                          -> x
+///   (UConvert x : iN to iK) & <low N bits>  -> UConvert x
+/// Only the right-hand operand is inspected for constants (scalar or splat).
+OpFoldResult
+spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
+  // x & x -> x (idempotence); this needs no constant operand.
+  if (getOperand1() == getOperand2())
+    return getOperand1();
+
+  std::optional<APInt> rhsVal = extractIntConstant(adaptor.getOperand2());
+  if (!rhsVal)
+    return {};
+
+  const APInt &rhsMask = *rhsVal;
+
+  // x & 0 -> 0
+  if (rhsMask.isZero())
+    return getOperand2();
+
+  // x & <all ones> -> x
+  if (rhsMask.isAllOnes())
+    return getOperand1();
+
+  // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
+  // A zero-extending UConvert leaves bits [N, K) clear, so a mask that keeps
+  // all low N bits is a no-op. OpUConvert may also truncate (N > K). In that
+  // case zero-extending the K-bit mask to N bits leaves its top bits clear,
+  // the all-ones check fails, and we conservatively do not fold.
+  if (auto uconvert = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
+    unsigned srcBits =
+        getElementTypeOrSelf(uconvert.getOperand()).getIntOrFloatBitWidth();
+    if (rhsMask.zextOrTrunc(srcBits).isAllOnes())
+      return getOperand1();
+  }
+
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseOrOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `spirv.BitwiseOr` when its result is trivially known:
+///   x | x          -> x
+///   x | 0          -> x
+///   x | <all ones> -> <all ones>
+/// Only the right-hand operand is inspected for constants (scalar or splat).
+OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
+  // x | x -> x (idempotence); this needs no constant operand.
+  if (getOperand1() == getOperand2())
+    return getOperand1();
+
+  std::optional<APInt> rhsVal = extractIntConstant(adaptor.getOperand2());
+  if (!rhsVal)
+    return {};
+
+  const APInt &rhsMask = *rhsVal;
+
+  // x | 0 -> x
+  if (rhsMask.isZero())
+    return getOperand1();
+
+  // x | <all ones> -> <all ones>; reuse the existing constant operand.
+  if (rhsMask.isAllOnes())
+    return getOperand2();
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.ImageQuerySize
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
index eeaa607b56040d7..82a2316f6c784fb 100644
--- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --verify-diagnostics --canonicalize %s \
+// RUN:  | FileCheck %s --check-prefix=CANON
 
 //===----------------------------------------------------------------------===//
 // spirv.BitCount
@@ -82,18 +84,56 @@ func.func @bitreverse(%arg: i32) -> i32 {
 // spirv.BitwiseOr
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: func @bitwise_or_scalar
 func.func @bitwise_or_scalar(%arg: i32) -> i32 {
   // CHECK: spirv.BitwiseOr
   %0 = spirv.BitwiseOr %arg, %arg : i32
   return %0 : i32
 }
 
+// CHECK-LABEL: func @bitwise_or_vector
 func.func @bitwise_or_vector(%arg: vector<4xi32>) -> vector<4xi32> {
   // CHECK: spirv.BitwiseOr
   %0 = spirv.BitwiseOr %arg, %arg : vector<4xi32>
   return %0 : vector<4xi32>
 }
 
+// Folds exercised by the --canonicalize RUN line: `x | 0` folds to `x`, and
+// `x | <all ones>` folds to the all-ones constant. Splat vector constants are
+// expected to fold exactly like scalars.
+// CANON-LABEL: func @bitwise_or_zero
+// CANON-SAME:    (%[[ARG:.+]]: i32)
+func.func @bitwise_or_zero(%arg: i32) -> i32 {
+  // CANON: return %[[ARG]]
+  %zero = spirv.Constant 0 : i32
+  %0 = spirv.BitwiseOr %arg, %zero : i32
+  return %0 : i32
+}
+
+// CANON-LABEL: func @bitwise_or_zero_vector
+// CANON-SAME:    (%[[ARG:.+]]: vector<4xi32>)
+func.func @bitwise_or_zero_vector(%arg: vector<4xi32>) -> vector<4xi32> {
+  // CANON: return %[[ARG]]
+  %zero = spirv.Constant dense<0> : vector<4xi32>
+  %0 = spirv.BitwiseOr %arg, %zero : vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// 255 is all ones for i8; the surviving constant is printed back as -1.
+// CANON-LABEL: func @bitwise_or_all_ones
+func.func @bitwise_or_all_ones(%arg: i8) -> i8 {
+  // CANON: %[[CST:.+]] = spirv.Constant -1
+  // CANON: return %[[CST]]
+  %ones = spirv.Constant 255 : i8
+  %0 = spirv.BitwiseOr %arg, %ones : i8
+  return %0 : i8
+}
+
+// CANON-LABEL: func @bitwise_or_all_ones_vector
+func.func @bitwise_or_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> {
+  // CANON: %[[CST:.+]] = spirv.Constant dense<-1>
+  // CANON: return %[[CST]]
+  %ones = spirv.Constant dense<255> : vector<3xi8>
+  %0 = spirv.BitwiseOr %arg, %ones : vector<3xi8>
+  return %0 : vector<3xi8>
+}
+
 // -----
 
 func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 {
@@ -134,18 +174,101 @@ func.func @bitwise_xor_float(%arg0: f16, %arg1: f16) -> f16 {
 // spirv.BitwiseAnd
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: func @bitwise_and_scalar
 func.func @bitwise_and_scalar(%arg: i32) -> i32 {
   // CHECK: spirv.BitwiseAnd
   %0 = spirv.BitwiseAnd %arg, %arg : i32
   return %0 : i32
 }
 
+// CHECK-LABEL: func @bitwise_and_vector
 func.func @bitwise_and_vector(%arg: vector<4xi32>) -> vector<4xi32> {
   // CHECK: spirv.BitwiseAnd
   %0 = spirv.BitwiseAnd %arg, %arg : vector<4xi32>
   return %0 : vector<4xi32>
 }
 
+// Folds exercised by the --canonicalize RUN line: `x & 0` folds to 0,
+// `x & <all ones>` folds to `x`, and `(UConvert x) & mask` folds to the
+// UConvert result when the mask keeps every bit of the pre-conversion width.
+// CANON-LABEL: func @bitwise_and_zero
+func.func @bitwise_and_zero(%arg: i32) -> i32 {
+  // CANON: %[[CST:.+]] = spirv.Constant 0
+  // CANON: return %[[CST]]
+  %zero = spirv.Constant 0 : i32
+  %0 = spirv.BitwiseAnd %arg, %zero : i32
+  return %0 : i32
+}
+
+// CANON-LABEL: func @bitwise_and_zero_vector
+func.func @bitwise_and_zero_vector(%arg: vector<4xi32>) -> vector<4xi32> {
+  // CANON: %[[CST:.+]] = spirv.Constant dense<0>
+  // CANON: return %[[CST]]
+  %zero = spirv.Constant dense<0> : vector<4xi32>
+  %0 = spirv.BitwiseAnd %arg, %zero : vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// CANON-LABEL: func @bitwise_and_all_ones
+// CANON-SAME:    (%[[ARG:.+]]: i8)
+func.func @bitwise_and_all_ones(%arg: i8) -> i8 {
+  // CANON: return %[[ARG]]
+  %ones = spirv.Constant 255 : i8
+  %0 = spirv.BitwiseAnd %arg, %ones : i8
+  return %0 : i8
+}
+
+// CANON-LABEL: func @bitwise_and_all_ones_vector
+// CANON-SAME:    (%[[ARG:.+]]: vector<3xi8>)
+func.func @bitwise_and_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> {
+  // CANON: return %[[ARG]]
+  %ones = spirv.Constant dense<255> : vector<3xi8>
+  %0 = spirv.BitwiseAnd %arg, %ones : vector<3xi8>
+  return %0 : vector<3xi8>
+}
+
+// 255 keeps exactly the low 8 bits, which are the only bits an i8 -> i32
+// UConvert can set, so the AND is a no-op.
+// CANON-LABEL: func @bitwise_and_zext_1
+// CANON-SAME:    (%[[ARG:.+]]: i8)
+func.func @bitwise_and_zext_1(%arg: i8) -> i32 {
+  // CANON: %[[ZEXT:.+]] = spirv.UConvert %[[ARG]]
+  // CANON: return %[[ZEXT]]
+  %zext = spirv.UConvert %arg : i8 to i32
+  %ones = spirv.Constant 255 : i32
+  %0 = spirv.BitwiseAnd %zext, %ones : i32
+  return %0 : i32
+}
+
+// Only the low 8 bits of the mask matter (0x12345ff ends in 0xff); the mask's
+// higher bits meet bits that are already zero after the UConvert.
+// CANON-LABEL: func @bitwise_and_zext_2
+// CANON-SAME:    (%[[ARG:.+]]: i8)
+func.func @bitwise_and_zext_2(%arg: i8) -> i32 {
+  // CANON: %[[ZEXT:.+]] = spirv.UConvert %[[ARG]]
+  // CANON: return %[[ZEXT]]
+  %zext = spirv.UConvert %arg : i8 to i32
+  %ones = spirv.Constant 0x12345ff : i32
+  %0 = spirv.BitwiseAnd %zext, %ones : i32
+  return %0 : i32
+}
+
+// Negative case: 254 clears bit 0 of the original i8 value, so the AND is not
+// a no-op and must survive canonicalization.
+// CANON-LABEL: func @bitwise_and_zext_3
+// CANON-SAME:    (%[[ARG:.+]]: i8)
+func.func @bitwise_and_zext_3(%arg: i8) -> i32 {
+  // CANON: %[[ZEXT:.+]] = spirv.UConvert %[[ARG]]
+  // CANON: %[[AND:.+]]  = spirv.BitwiseAnd %[[ZEXT]]
+  // CANON: return %[[AND]]
+  %zext = spirv.UConvert %arg : i8 to i32
+  %ones = spirv.Constant 254 : i32
+  %0 = spirv.BitwiseAnd %zext, %ones : i32
+  return %0 : i32
+}
+
+// Same as bitwise_and_zext_1, but with a splat vector mask.
+// CANON-LABEL: func @bitwise_and_zext_vector
+// CANON-SAME:    (%[[ARG:.+]]: vector<2xi8>)
+func.func @bitwise_and_zext_vector(%arg: vector<2xi8>) -> vector<2xi32> {
+  // CANON: %[[ZEXT:.+]] = spirv.UConvert %[[ARG]]
+  // CANON: return %[[ZEXT]]
+  %zext = spirv.UConvert %arg : vector<2xi8> to vector<2xi32>
+  %ones = spirv.Constant dense<255> : vector<2xi32>
+  %0 = spirv.BitwiseAnd %zext, %ones : vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
 // -----
 
 func.func @bitwise_and_float(%arg0: f16, %arg1: f16) -> f16 {


        


More information about the Mlir-commits mailing list