[Mlir-commits] [mlir] f1162fb - [mlir][spirv] Add folds for `BitwiseAnd` and `BitwiseOr`
Jakub Kuderski
llvmlistbot at llvm.org
Mon Sep 4 10:13:00 PDT 2023
Author: Jakub Kuderski
Date: 2023-09-04T13:12:40-04:00
New Revision: f1162fb677bf6641928a2390edac5a8d62cf1250
URL: https://github.com/llvm/llvm-project/commit/f1162fb677bf6641928a2390edac5a8d62cf1250
DIFF: https://github.com/llvm/llvm-project/commit/f1162fb677bf6641928a2390edac5a8d62cf1250.diff
LOG: [mlir][spirv] Add folds for `BitwiseAnd` and `BitwiseOr`
The newly folded patterns get produced by type emulation code, so we
cannot rely on the Arith folders handling these folds.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D159394
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index f20f113a7331f74..286f4de6f90f621 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -338,6 +338,8 @@ def SPIRV_BitwiseAndOp : SPIRV_BitBinaryOp<"BitwiseAnd",
%2 = spirv.BitwiseAnd %0, %1 : vector<4xi32>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -373,6 +375,8 @@ def SPIRV_BitwiseOrOp : SPIRV_BitBinaryOp<"BitwiseOr",
%2 = spirv.BitwiseOr %0, %1 : vector<4xi32>
```
}];
+
+ let hasFolder = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index fb522f942e2b8ad..f7ab3c0702a98b0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -35,6 +35,7 @@
#include "llvm/ADT/StringExtras.h"
#include <cassert>
#include <numeric>
+#include <optional>
#include <type_traits>
using namespace mlir;
@@ -1961,6 +1962,72 @@ LogicalResult spirv::ShiftRightLogicalOp::verify() {
return verifyShiftOp(*this);
}
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseAndOp
+//===----------------------------------------------------------------------===//
+
+/// Extracts the integer payload of a constant operand attribute.
+///
+/// Both scalar `IntegerAttr`s and splat elements attributes (e.g. a
+/// `dense<255> : vector<3xi8>` constant) are accepted; any other attribute,
+/// including a null one (as passed for a non-constant operand), gives nullopt.
+static std::optional<APInt> extractIntConstant(Attribute attr) {
+  auto splat = dyn_cast_if_present<SplatElementsAttr>(attr);
+  Attribute scalar = splat ? splat.getSplatValue<Attribute>() : attr;
+  if (auto intAttr = dyn_cast_if_present<IntegerAttr>(scalar))
+    return intAttr.getValue();
+  return std::nullopt;
+}
+
+/// Folds `x & c` where `c` is an integer or splat integer constant:
+///   - x & 0 -> 0 and x & <all ones> -> x;
+///   - (UConvert x : iN to iK) & c -> UConvert x when the low N bits of `c`
+///     are all set, as the converted value has no set bits at or above bit N.
+OpFoldResult
+spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
+  std::optional<APInt> mask = extractIntConstant(adaptor.getOperand2());
+  if (!mask)
+    return {};
+
+  // Trivial masks forward an existing value: the zero constant, or `x`.
+  if (mask->isZero())
+    return getOperand2();
+  if (mask->isAllOnes())
+    return getOperand1();
+
+  auto convert = getOperand1().getDefiningOp<spirv::UConvertOp>();
+  if (!convert)
+    return {};
+
+  unsigned srcBits =
+      getElementTypeOrSelf(convert.getOperand()).getIntOrFloatBitWidth();
+  if (mask->zextOrTrunc(srcBits).isAllOnes())
+    return getOperand1();
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseOrOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `x | c` where `c` is an integer or splat integer constant:
+///   x | 0 -> x
+///   x | <all ones> -> <all ones>
+/// Only the right-hand operand is inspected for a constant.
+OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
+  std::optional<APInt> bits = extractIntConstant(adaptor.getOperand2());
+  if (!bits)
+    return {};
+
+  // OR with zero is the identity, while OR with all ones saturates to the
+  // existing all-ones constant. Any other constant is left alone.
+  if (bits->isZero())
+    return getOperand1();
+  if (bits->isAllOnes())
+    return getOperand2();
+  return {};
+}
+
//===----------------------------------------------------------------------===//
// spirv.ImageQuerySize
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
index eeaa607b56040d7..82a2316f6c784fb 100644
--- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --verify-diagnostics --canonicalize %s \
+// RUN: | FileCheck %s --check-prefix=CANON
//===----------------------------------------------------------------------===//
// spirv.BitCount
@@ -82,18 +84,56 @@ func.func @bitreverse(%arg: i32) -> i32 {
// spirv.BitwiseOr
//===----------------------------------------------------------------------===//
+// CHECK-LABEL: func @bitwise_or_scalar
func.func @bitwise_or_scalar(%arg: i32) -> i32 {
// CHECK: spirv.BitwiseOr
%0 = spirv.BitwiseOr %arg, %arg : i32
return %0 : i32
}
+// CHECK-LABEL: func @bitwise_or_vector
func.func @bitwise_or_vector(%arg: vector<4xi32>) -> vector<4xi32> {
// CHECK: spirv.BitwiseOr
%0 = spirv.BitwiseOr %arg, %arg : vector<4xi32>
return %0 : vector<4xi32>
}
+// CANON-LABEL: func @bitwise_or_zero
+// CANON-SAME: (%[[ARG:.+]]: i32)
+func.func @bitwise_or_zero(%arg: i32) -> i32 { // x | 0 folds to x.
+ // CANON: return %[[ARG]]
+ %zero = spirv.Constant 0 : i32
+ %0 = spirv.BitwiseOr %arg, %zero : i32
+ return %0 : i32
+}
+
+// CANON-LABEL: func @bitwise_or_zero_vector
+// CANON-SAME: (%[[ARG:.+]]: vector<4xi32>)
+func.func @bitwise_or_zero_vector(%arg: vector<4xi32>) -> vector<4xi32> { // Splat zero: x | 0 folds to x.
+ // CANON: return %[[ARG]]
+ %zero = spirv.Constant dense<0> : vector<4xi32>
+ %0 = spirv.BitwiseOr %arg, %zero : vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// CANON-LABEL: func @bitwise_or_all_ones
+func.func @bitwise_or_all_ones(%arg: i8) -> i8 { // x | <all ones> folds to the all-ones constant.
+ // CANON: %[[CST:.+]] = spirv.Constant -1
+ // CANON: return %[[CST]]
+ %ones = spirv.Constant 255 : i8 // All bits set for i8; expected to print as -1 after folding.
+ %0 = spirv.BitwiseOr %arg, %ones : i8
+ return %0 : i8
+}
+
+// CANON-LABEL: func @bitwise_or_all_ones_vector
+func.func @bitwise_or_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> { // Splat variant of x | <all ones>.
+ // CANON: %[[CST:.+]] = spirv.Constant dense<-1>
+ // CANON: return %[[CST]]
+ %ones = spirv.Constant dense<255> : vector<3xi8> // All bits set in every i8 lane.
+ %0 = spirv.BitwiseOr %arg, %ones : vector<3xi8>
+ return %0 : vector<3xi8>
+}
+
// -----
func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 {
@@ -134,18 +174,101 @@ func.func @bitwise_xor_float(%arg0: f16, %arg1: f16) -> f16 {
// spirv.BitwiseAnd
//===----------------------------------------------------------------------===//
+// CHECK-LABEL: func @bitwise_and_scalar
func.func @bitwise_and_scalar(%arg: i32) -> i32 {
// CHECK: spirv.BitwiseAnd
%0 = spirv.BitwiseAnd %arg, %arg : i32
return %0 : i32
}
+// CHECK-LABEL: func @bitwise_and_vector
func.func @bitwise_and_vector(%arg: vector<4xi32>) -> vector<4xi32> {
// CHECK: spirv.BitwiseAnd
%0 = spirv.BitwiseAnd %arg, %arg : vector<4xi32>
return %0 : vector<4xi32>
}
+// CANON-LABEL: func @bitwise_and_zero
+func.func @bitwise_and_zero(%arg: i32) -> i32 { // x & 0 folds to the zero constant.
+ // CANON: %[[CST:.+]] = spirv.Constant 0
+ // CANON: return %[[CST]]
+ %zero = spirv.Constant 0 : i32
+ %0 = spirv.BitwiseAnd %arg, %zero : i32
+ return %0 : i32
+}
+
+// CANON-LABEL: func @bitwise_and_zero_vector
+func.func @bitwise_and_zero_vector(%arg: vector<4xi32>) -> vector<4xi32> { // Splat zero: x & 0 folds to the zero constant.
+ // CANON: %[[CST:.+]] = spirv.Constant dense<0>
+ // CANON: return %[[CST]]
+ %zero = spirv.Constant dense<0> : vector<4xi32>
+ %0 = spirv.BitwiseAnd %arg, %zero : vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// CANON-LABEL: func @bitwise_and_all_ones
+// CANON-SAME: (%[[ARG:.+]]: i8)
+func.func @bitwise_and_all_ones(%arg: i8) -> i8 { // x & <all ones> folds to x.
+ // CANON: return %[[ARG]]
+ %ones = spirv.Constant 255 : i8 // All bits set for i8.
+ %0 = spirv.BitwiseAnd %arg, %ones : i8
+ return %0 : i8
+}
+
+// CANON-LABEL: func @bitwise_and_all_ones_vector
+// CANON-SAME: (%[[ARG:.+]]: vector<3xi8>)
+func.func @bitwise_and_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> { // Splat variant of x & <all ones>.
+ // CANON: return %[[ARG]]
+ %ones = spirv.Constant dense<255> : vector<3xi8> // All bits set in every i8 lane.
+ %0 = spirv.BitwiseAnd %arg, %ones : vector<3xi8>
+ return %0 : vector<3xi8>
+}
+
+// CANON-LABEL: func @bitwise_and_zext_1
+// CANON-SAME: (%[[ARG:.+]]: i8)
+func.func @bitwise_and_zext_1(%arg: i8) -> i32 { // Mask covering exactly the 8 source bits is a no-op.
+ // CANON: %[[ZEXT:.+]] = spirv.UConvert %[[ARG]]
+ // CANON: return %[[ZEXT]]
+ %zext = spirv.UConvert %arg : i8 to i32
+ %ones = spirv.Constant 255 : i32 // 0xff: exactly the low 8 bits set.
+ %0 = spirv.BitwiseAnd %zext, %ones : i32
+ return %0 : i32
+}
+
+// CANON-LABEL: func @bitwise_and_zext_2
+// CANON-SAME: (%[[ARG:.+]]: i8)
+func.func @bitwise_and_zext_2(%arg: i8) -> i32 { // Mask bits above the source width do not block the fold.
+ // CANON: %[[ZEXT:.+]] = spirv.UConvert %[[ARG]]
+ // CANON: return %[[ZEXT]]
+ %zext = spirv.UConvert %arg : i8 to i32
+ %ones = spirv.Constant 0x12345ff : i32 // Low 8 bits are 0xff; the higher mask bits are irrelevant.
+ %0 = spirv.BitwiseAnd %zext, %ones : i32
+ return %0 : i32
+}
+
+// CANON-LABEL: func @bitwise_and_zext_3
+// CANON-SAME: (%[[ARG:.+]]: i8)
+func.func @bitwise_and_zext_3(%arg: i8) -> i32 { // Negative case: the AND must be kept.
+ // CANON: %[[ZEXT:.+]] = spirv.UConvert %[[ARG]]
+ // CANON: %[[AND:.+]] = spirv.BitwiseAnd %[[ZEXT]]
+ // CANON: return %[[AND]]
+ %zext = spirv.UConvert %arg : i8 to i32
+ %ones = spirv.Constant 254 : i32 // 0xfe: bit 0 is clear, so the low 8 bits are not all set.
+ %0 = spirv.BitwiseAnd %zext, %ones : i32
+ return %0 : i32
+}
+
+// CANON-LABEL: func @bitwise_and_zext_vector
+// CANON-SAME: (%[[ARG:.+]]: vector<2xi8>)
+func.func @bitwise_and_zext_vector(%arg: vector<2xi8>) -> vector<2xi32> { // Splat variant of the UConvert fold.
+ // CANON: %[[ZEXT:.+]] = spirv.UConvert %[[ARG]]
+ // CANON: return %[[ZEXT]]
+ %zext = spirv.UConvert %arg : vector<2xi8> to vector<2xi32>
+ %ones = spirv.Constant dense<255> : vector<2xi32> // Low 8 bits set in every lane.
+ %0 = spirv.BitwiseAnd %zext, %ones : vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
// -----
func.func @bitwise_and_float(%arg0: f16, %arg1: f16) -> f16 {
More information about the Mlir-commits
mailing list