[Mlir-commits] [mlir] 141b7d4 - [mlir][spirv] Fix UnifyAliasedResourcePass for 64-bit index

Lei Zhang llvmlistbot at llvm.org
Tue Mar 14 16:54:36 PDT 2023


Author: Lei Zhang
Date: 2023-03-14T23:54:27Z
New Revision: 141b7d49a3a4365c52f12ab6e38664cebce538b5

URL: https://github.com/llvm/llvm-project/commit/141b7d49a3a4365c52f12ab6e38664cebce538b5
DIFF: https://github.com/llvm/llvm-project/commit/141b7d49a3a4365c52f12ab6e38664cebce538b5.diff

LOG: [mlir][spirv] Fix UnifyAliasedResourcePass for 64-bit index

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D145079

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
    mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
index 3acfba23ddfd1..1713c4490f960 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -366,7 +366,6 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
     }
 
     Location loc = acOp.getLoc();
-    auto i32Type = rewriter.getI32Type();
 
     if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
       // The source indices are for a buffer with scalar element types. Rewrite
@@ -376,16 +375,19 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
       int srcNumBytes = *srcElemType.getSizeInBytes();
       int dstNumBytes = *dstElemType.getSizeInBytes();
       assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
-      int ratio = dstNumBytes / srcNumBytes;
-      auto ratioValue = rewriter.create<spirv::ConstantOp>(
-          loc, i32Type, rewriter.getI32IntegerAttr(ratio));
 
       auto indices = llvm::to_vector<4>(acOp.getIndices());
       Value oldIndex = indices.back();
+      Type indexType = oldIndex.getType();
+
+      int ratio = dstNumBytes / srcNumBytes;
+      auto ratioValue = rewriter.create<spirv::ConstantOp>(
+          loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
+
       indices.back() =
-          rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
+          rewriter.create<spirv::SDivOp>(loc, indexType, oldIndex, ratioValue);
       indices.push_back(
-          rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));
+          rewriter.create<spirv::SModOp>(loc, indexType, oldIndex, ratioValue));
 
       rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
           acOp, adaptor.getBasePtr(), indices);
@@ -400,14 +402,17 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
       int srcNumBytes = *srcElemType.getSizeInBytes();
       int dstNumBytes = *dstElemType.getSizeInBytes();
       assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
-      int ratio = srcNumBytes / dstNumBytes;
-      auto ratioValue = rewriter.create<spirv::ConstantOp>(
-          loc, i32Type, rewriter.getI32IntegerAttr(ratio));
 
       auto indices = llvm::to_vector<4>(acOp.getIndices());
       Value oldIndex = indices.back();
+      Type indexType = oldIndex.getType();
+
+      int ratio = srcNumBytes / dstNumBytes;
+      auto ratioValue = rewriter.create<spirv::ConstantOp>(
+          loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
+
       indices.back() =
-          rewriter.create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue);
+          rewriter.create<spirv::IMulOp>(loc, indexType, oldIndex, ratioValue);
 
       rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
           acOp, adaptor.getBasePtr(), indices);

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
index 1d532f3f13812..8801fdb0e1653 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
@@ -32,6 +32,33 @@ spirv.module Logical GLSL450 {
 
 // -----
 
+spirv.module Logical GLSL450 {
+  spirv.GlobalVariable @var01s bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+  spirv.GlobalVariable @var01v bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+  spirv.func @load_store_scalar_64bit(%index: i64) -> f32 "None" {
+    %c0 = spirv.Constant 0 : i64
+    %addr = spirv.mlir.addressof @var01s : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac = spirv.AccessChain %addr[%c0, %index] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i64, i64
+    %value = spirv.Load "StorageBuffer" %ac : f32
+    spirv.Store "StorageBuffer" %ac, %value : f32
+    spirv.ReturnValue %value : f32
+  }
+}
+
+// CHECK-LABEL: spirv.module
+
+// CHECK-NOT: @var01s
+//     CHECK: spirv.GlobalVariable @var01v bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK-NOT: @var01s
+
+//     CHECK: spirv.func @load_store_scalar_64bit(%[[INDEX:.+]]: i64)
+// CHECK-DAG:   %[[C4:.+]] = spirv.Constant 4 : i64
+//     CHECK:   spirv.SDiv %[[INDEX]], %[[C4]] : i64
+//     CHECK:   spirv.SMod %[[INDEX]], %[[C4]] : i64
+
+// -----
+
 spirv.module Logical GLSL450 {
   spirv.GlobalVariable @var01s bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>
   spirv.GlobalVariable @var01v bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>


        


More information about the Mlir-commits mailing list