[Mlir-commits] [mlir] 141b7d4 - [mlir][spirv] Fix UnifyAliasedResourcePass for 64-bit index
Lei Zhang
llvmlistbot at llvm.org
Tue Mar 14 16:54:36 PDT 2023
Author: Lei Zhang
Date: 2023-03-14T23:54:27Z
New Revision: 141b7d49a3a4365c52f12ab6e38664cebce538b5
URL: https://github.com/llvm/llvm-project/commit/141b7d49a3a4365c52f12ab6e38664cebce538b5
DIFF: https://github.com/llvm/llvm-project/commit/141b7d49a3a4365c52f12ab6e38664cebce538b5.diff
LOG: [mlir][spirv] Fix UnifyAliasedResourcePass for 64-bit index
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D145079
Added:
Modified:
mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
index 3acfba23ddfd1..1713c4490f960 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -366,7 +366,6 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
}
Location loc = acOp.getLoc();
- auto i32Type = rewriter.getI32Type();
if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
// The source indices are for a buffer with scalar element types. Rewrite
@@ -376,16 +375,19 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
int srcNumBytes = *srcElemType.getSizeInBytes();
int dstNumBytes = *dstElemType.getSizeInBytes();
assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
- int ratio = dstNumBytes / srcNumBytes;
- auto ratioValue = rewriter.create<spirv::ConstantOp>(
- loc, i32Type, rewriter.getI32IntegerAttr(ratio));
auto indices = llvm::to_vector<4>(acOp.getIndices());
Value oldIndex = indices.back();
+ Type indexType = oldIndex.getType();
+
+ int ratio = dstNumBytes / srcNumBytes;
+ auto ratioValue = rewriter.create<spirv::ConstantOp>(
+ loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
+
indices.back() =
- rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
+ rewriter.create<spirv::SDivOp>(loc, indexType, oldIndex, ratioValue);
indices.push_back(
- rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));
+ rewriter.create<spirv::SModOp>(loc, indexType, oldIndex, ratioValue));
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
acOp, adaptor.getBasePtr(), indices);
@@ -400,14 +402,17 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
int srcNumBytes = *srcElemType.getSizeInBytes();
int dstNumBytes = *dstElemType.getSizeInBytes();
assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
- int ratio = srcNumBytes / dstNumBytes;
- auto ratioValue = rewriter.create<spirv::ConstantOp>(
- loc, i32Type, rewriter.getI32IntegerAttr(ratio));
auto indices = llvm::to_vector<4>(acOp.getIndices());
Value oldIndex = indices.back();
+ Type indexType = oldIndex.getType();
+
+ int ratio = srcNumBytes / dstNumBytes;
+ auto ratioValue = rewriter.create<spirv::ConstantOp>(
+ loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
+
indices.back() =
- rewriter.create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue);
+ rewriter.create<spirv::IMulOp>(loc, indexType, oldIndex, ratioValue);
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
acOp, adaptor.getBasePtr(), indices);
diff --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
index 1d532f3f13812..8801fdb0e1653 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
@@ -32,6 +32,33 @@ spirv.module Logical GLSL450 {
// -----
+spirv.module Logical GLSL450 {
+ spirv.GlobalVariable @var01s bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ spirv.GlobalVariable @var01v bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+ spirv.func @load_store_scalar_64bit(%index: i64) -> f32 "None" {
+ %c0 = spirv.Constant 0 : i64
+ %addr = spirv.mlir.addressof @var01s : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ %ac = spirv.AccessChain %addr[%c0, %index] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i64, i64
+ %value = spirv.Load "StorageBuffer" %ac : f32
+ spirv.Store "StorageBuffer" %ac, %value : f32
+ spirv.ReturnValue %value : f32
+ }
+}
+
+// CHECK-LABEL: spirv.module
+
+// CHECK-NOT: @var01s
+// CHECK: spirv.GlobalVariable @var01v bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK-NOT: @var01s
+
+// CHECK: spirv.func @load_store_scalar_64bit(%[[INDEX:.+]]: i64)
+// CHECK-DAG: %[[C4:.+]] = spirv.Constant 4 : i64
+// CHECK: spirv.SDiv %[[INDEX]], %[[C4]] : i64
+// CHECK: spirv.SMod %[[INDEX]], %[[C4]] : i64
+
+// -----
+
spirv.module Logical GLSL450 {
spirv.GlobalVariable @var01s bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>
spirv.GlobalVariable @var01v bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
More information about the Mlir-commits mailing list.