[Mlir-commits] [mlir] 69a3c9c - [mlir][spirv] Support sub-byte element vector extension/truncation

Lei Zhang llvmlistbot at llvm.org
Sat Aug 12 18:43:48 PDT 2023


Author: Lei Zhang
Date: 2023-08-12T18:43:01-07:00
New Revision: 69a3c9cddf1919b7caaa625efa1829cd5628bf10

URL: https://github.com/llvm/llvm-project/commit/69a3c9cddf1919b7caaa625efa1829cd5628bf10
DIFF: https://github.com/llvm/llvm-project/commit/69a3c9cddf1919b7caaa625efa1829cd5628bf10.diff

LOG: [mlir][spirv] Support sub-byte element vector extension/truncation

This change allows vector type conversion involving sub-byte
element types (e.g. `vector<3xi4>` now converts to `vector<3xi32>`).

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D157791

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
    mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 01c694de08a9e8..39d6603a46f965 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -101,17 +101,9 @@ bool CompositeType::classof(Type type) {
 }
 
 bool CompositeType::isValid(VectorType type) {
-  switch (type.getNumElements()) {
-  case 2:
-  case 3:
-  case 4:
-  case 8:
-  case 16:
-    break;
-  default:
-    return false;
-  }
-  return type.getRank() == 1 && llvm::isa<ScalarType>(type.getElementType());
+  return type.getRank() == 1 &&
+         llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
+         llvm::isa<ScalarType>(type.getElementType());
 }
 
 Type CompositeType::getElementType(unsigned index) const {

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 7fea90a7dc8f42..94ce4ebb812947 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
@@ -303,16 +304,35 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
   type = cast<VectorType>(convertIndexElementType(type, options));
   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
   if (!scalarType) {
-    LLVM_DEBUG(llvm::dbgs()
-               << type << " illegal: cannot convert non-scalar element type\n");
-    return nullptr;
+    // If this is not a spec allowed scalar type, try to handle sub-byte integer
+    // types.
+    auto intType = dyn_cast<IntegerType>(type.getElementType());
+    if (!intType) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << type
+                 << " illegal: cannot convert non-scalar element type\n");
+      return nullptr;
+    }
+
+    Type elementType = convertSubByteIntegerType(options, intType);
+    if (type.getRank() <= 1 && type.getNumElements() == 1)
+      return elementType;
+
+    if (type.getNumElements() > 4) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << type << " illegal: > 4-element unimplemented\n");
+      return nullptr;
+    }
+
+    return VectorType::get(type.getShape(), elementType);
   }
 
   if (type.getRank() <= 1 && type.getNumElements() == 1)
     return convertScalarType(targetEnv, options, scalarType, storageClass);
 
   if (!spirv::CompositeType::isValid(type)) {
-    LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: not a valid composite type\n");
     return nullptr;
   }
 

diff  --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 604f85757537de..aa2cd649ecd789 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -990,9 +990,9 @@ func.func @fpext2(%arg0 : f32) -> f64 {
   return %0: f64
 }
 
-// CHECK-LABEL: @trunci4
+// CHECK-LABEL: @trunci4_scalar
 //  CHECK-SAME: %[[ARG:.*]]: i32
-func.func @trunci4(%arg0 : i32) -> i4 {
+func.func @trunci4_scalar(%arg0 : i32) -> i4 {
   // CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
   // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[ARG]], %[[MASK]] : i32
   %0 = arith.trunci %arg0 : i32 to i4
@@ -1001,8 +1001,19 @@ func.func @trunci4(%arg0 : i32) -> i4 {
   return %0 : i4
 }
 
-// CHECK-LABEL: @zexti4
-func.func @zexti4(%arg0: i4) -> i32 {
+// CHECK-LABEL: @trunci4_vector
+//  CHECK-SAME: %[[ARG:.*]]: vector<2xi32>
+func.func @trunci4_vector(%arg0 : vector<2xi32>) -> vector<2xi4> {
+  // CHECK: %[[MASK:.+]] = spirv.Constant dense<15> : vector<2xi32>
+  // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[ARG]], %[[MASK]] : vector<2xi32>
+  %0 = arith.trunci %arg0 : vector<2xi32> to vector<2xi4>
+  // CHECK: %[[RET:.+]] = builtin.unrealized_conversion_cast %[[AND]] : vector<2xi32> to vector<2xi4>
+  // CHECK: return %[[RET]] : vector<2xi4>
+  return %0 : vector<2xi4>
+}
+
+// CHECK-LABEL: @zexti4_scalar
+func.func @zexti4_scalar(%arg0: i4) -> i32 {
   // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %{{.+}} : i4 to i32
   // CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
   // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[INPUT]], %[[MASK]] : i32
@@ -1011,8 +1022,18 @@ func.func @zexti4(%arg0: i4) -> i32 {
   return %0 : i32
 }
 
-// CHECK-LABEL: @sexti4
-func.func @sexti4(%arg0: i4) -> i32 {
+// CHECK-LABEL: @zexti4_vector
+func.func @zexti4_vector(%arg0: vector<3xi4>) -> vector<3xi32> {
+  // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<3xi4> to vector<3xi32>
+  // CHECK: %[[MASK:.+]] = spirv.Constant dense<15> : vector<3xi32>
+  // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[INPUT]], %[[MASK]] : vector<3xi32>
+  %0 = arith.extui %arg0 : vector<3xi4> to vector<3xi32>
+  // CHECK: return %[[AND]] : vector<3xi32>
+  return %0 : vector<3xi32>
+}
+
+// CHECK-LABEL: @sexti4_scalar
+func.func @sexti4_scalar(%arg0: i4) -> i32 {
   // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %arg0 : i4 to i32
   // CHECK: %[[SIZE:.+]] = spirv.Constant 28 : i32
   // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[INPUT]], %[[SIZE]] : i32, i32
@@ -1022,6 +1043,17 @@ func.func @sexti4(%arg0: i4) -> i32 {
   return %0 : i32
 }
 
+// CHECK-LABEL: @sexti4_vector
+func.func @sexti4_vector(%arg0: vector<4xi4>) -> vector<4xi32> {
+  // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<4xi4> to vector<4xi32>
+  // CHECK: %[[SIZE:.+]] = spirv.Constant dense<28> : vector<4xi32>
+  // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[INPUT]], %[[SIZE]] : vector<4xi32>, vector<4xi32>
+  // CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[SL]], %[[SIZE]] : vector<4xi32>, vector<4xi32>
+  %0 = arith.extsi %arg0 : vector<4xi4> to vector<4xi32>
+  // CHECK: return %[[SR]] : vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
 } // end module
 
 // -----

diff  --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index ef1ee00b709fd4..82d750755ffe2e 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -105,6 +105,9 @@ func.func @integer2(%arg0: i8) { return }
 // CHECK: spirv.func @integer4(%{{.+}}: i32)
 func.func @integer4(%arg0: i4) { return }
 
+// CHECK: spirv.func @v3i4(%{{.+}}: vector<3xi32>)
+func.func @v3i4(%arg0: vector<3xi4>) { return }
+
 } // end module
 
 // -----


        


More information about the Mlir-commits mailing list