[Mlir-commits] [mlir] 69a3c9c - [mlir][spirv] Support sub-byte element vector extension/truncation
Lei Zhang
llvmlistbot at llvm.org
Sat Aug 12 18:43:48 PDT 2023
Author: Lei Zhang
Date: 2023-08-12T18:43:01-07:00
New Revision: 69a3c9cddf1919b7caaa625efa1829cd5628bf10
URL: https://github.com/llvm/llvm-project/commit/69a3c9cddf1919b7caaa625efa1829cd5628bf10
DIFF: https://github.com/llvm/llvm-project/commit/69a3c9cddf1919b7caaa625efa1829cd5628bf10.diff
LOG: [mlir][spirv] Support sub-byte element vector extension/truncation
This means allowing vector type conversions that involve sub-byte
element types.
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D157791
Added:
Modified:
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 01c694de08a9e8..39d6603a46f965 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -101,17 +101,9 @@ bool CompositeType::classof(Type type) {
}
bool CompositeType::isValid(VectorType type) {
- switch (type.getNumElements()) {
- case 2:
- case 3:
- case 4:
- case 8:
- case 16:
- break;
- default:
- return false;
- }
- return type.getRank() == 1 && llvm::isa<ScalarType>(type.getElementType());
+ return type.getRank() == 1 && // only 1-D vectors are composites
+ llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) && // accepted component counts
+ llvm::isa<ScalarType>(type.getElementType()); // elements must be SPIR-V scalars
}
Type CompositeType::getElementType(unsigned index) const {
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 7fea90a7dc8f42..94ce4ebb812947 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
@@ -303,16 +304,47 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
type = cast<VectorType>(convertIndexElementType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
- LLVM_DEBUG(llvm::dbgs()
- << type << " illegal: cannot convert non-scalar element type\n");
- return nullptr;
+ // If this is not a spec allowed scalar type, try to handle sub-byte integer
+ // types.
+ auto intType = dyn_cast<IntegerType>(type.getElementType());
+ if (!intType) {
+ LLVM_DEBUG(llvm::dbgs()
+ << type
+ << " illegal: cannot convert non-scalar element type\n");
+ return nullptr;
+ }
+
+ // convertSubByteIntegerType may return a null Type when it cannot convert
+ // intType; propagate that instead of passing a null element to VectorType::get.
+ Type elementType = convertSubByteIntegerType(options, intType);
+ if (!elementType)
+ return nullptr;
+
+ if (type.getRank() <= 1 && type.getNumElements() == 1)
+ return elementType;
+
+ // Mirror spirv::CompositeType::isValid: only 1-D vectors are legal. Without
+ // this, e.g. vector<2x2xi4> would be turned into the invalid vector<2x2xi32>.
+ if (type.getRank() != 1) {
+ LLVM_DEBUG(llvm::dbgs() << type << " illegal: not a 1-D vector\n");
+ return nullptr;
+ }
+
+ if (type.getNumElements() > 4) {
+ LLVM_DEBUG(llvm::dbgs()
+ << type << " illegal: > 4-element unimplemented\n");
+ return nullptr;
+ }
+
+ return VectorType::get(type.getShape(), elementType);
}
if (type.getRank() <= 1 && type.getNumElements() == 1)
return convertScalarType(targetEnv, options, scalarType, storageClass);
if (!spirv::CompositeType::isValid(type)) {
- LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
+ LLVM_DEBUG(llvm::dbgs()
+ << type << " illegal: not a valid composite type\n");
return nullptr;
}
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 604f85757537de..aa2cd649ecd789 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -990,9 +990,9 @@ func.func @fpext2(%arg0 : f32) -> f64 {
return %0: f64
}
-// CHECK-LABEL: @trunci4
+// CHECK-LABEL: @trunci4_scalar
// CHECK-SAME: %[[ARG:.*]]: i32
-func.func @trunci4(%arg0 : i32) -> i4 {
+func.func @trunci4_scalar(%arg0 : i32) -> i4 {
// CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
// CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[ARG]], %[[MASK]] : i32
%0 = arith.trunci %arg0 : i32 to i4
@@ -1001,8 +1001,19 @@ func.func @trunci4(%arg0 : i32) -> i4 {
return %0 : i4
}
-// CHECK-LABEL: @zexti4
-func.func @zexti4(%arg0: i4) -> i32 {
+// CHECK-LABEL: @trunci4_vector
+// CHECK-SAME: %[[ARG:.*]]: vector<2xi32>
+func.func @trunci4_vector(%arg0 : vector<2xi32>) -> vector<2xi4> { // vector counterpart of @trunci4_scalar
+ // CHECK: %[[MASK:.+]] = spirv.Constant dense<15> : vector<2xi32>
+ // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[ARG]], %[[MASK]] : vector<2xi32>
+ %0 = arith.trunci %arg0 : vector<2xi32> to vector<2xi4>
+ // CHECK: %[[RET:.+]] = builtin.unrealized_conversion_cast %[[AND]] : vector<2xi32> to vector<2xi4>
+ // CHECK: return %[[RET]] : vector<2xi4>
+ return %0 : vector<2xi4>
+}
+
+// CHECK-LABEL: @zexti4_scalar
+func.func @zexti4_scalar(%arg0: i4) -> i32 {
// CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %{{.+}} : i4 to i32
// CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
// CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[INPUT]], %[[MASK]] : i32
@@ -1011,8 +1022,18 @@ func.func @zexti4(%arg0: i4) -> i32 {
return %0 : i32
}
-// CHECK-LABEL: @sexti4
-func.func @sexti4(%arg0: i4) -> i32 {
+// CHECK-LABEL: @zexti4_vector
+func.func @zexti4_vector(%arg0: vector<3xi4>) -> vector<3xi32> { // vector counterpart of @zexti4_scalar
+ // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<3xi4> to vector<3xi32>
+ // CHECK: %[[MASK:.+]] = spirv.Constant dense<15> : vector<3xi32>
+ // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[INPUT]], %[[MASK]] : vector<3xi32>
+ %0 = arith.extui %arg0 : vector<3xi4> to vector<3xi32>
+ // CHECK: return %[[AND]] : vector<3xi32>
+ return %0 : vector<3xi32>
+}
+
+// CHECK-LABEL: @sexti4_scalar
+func.func @sexti4_scalar(%arg0: i4) -> i32 {
// CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %arg0 : i4 to i32
// CHECK: %[[SIZE:.+]] = spirv.Constant 28 : i32
// CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[INPUT]], %[[SIZE]] : i32, i32
@@ -1022,6 +1043,17 @@ func.func @sexti4(%arg0: i4) -> i32 {
return %0 : i32
}
+// CHECK-LABEL: @sexti4_vector
+func.func @sexti4_vector(%arg0: vector<4xi4>) -> vector<4xi32> { // vector counterpart of @sexti4_scalar
+ // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<4xi4> to vector<4xi32>
+ // CHECK: %[[SIZE:.+]] = spirv.Constant dense<28> : vector<4xi32>
+ // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[INPUT]], %[[SIZE]] : vector<4xi32>, vector<4xi32>
+ // CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[SL]], %[[SIZE]] : vector<4xi32>, vector<4xi32>
+ %0 = arith.extsi %arg0 : vector<4xi4> to vector<4xi32>
+ // CHECK: return %[[SR]] : vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
} // end module
// -----
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index ef1ee00b709fd4..82d750755ffe2e 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -105,6 +105,9 @@ func.func @integer2(%arg0: i8) { return }
// CHECK: spirv.func @integer4(%{{.+}}: i32)
func.func @integer4(%arg0: i4) { return }
+// CHECK: spirv.func @v3i4(%{{.+}}: vector<3xi32>)
+func.func @v3i4(%arg0: vector<3xi4>) { return } // sub-byte vector elements are widened to i32
+
} // end module
// -----
More information about the Mlir-commits mailing list