[Mlir-commits] [mlir] 38f8a3c - [mlir][spirv] Improve folding of MemRef to SPIRV Lowering (#85433)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 21 08:49:31 PDT 2024


Author: Finn Plummer
Date: 2024-03-21T08:49:27-07:00
New Revision: 38f8a3cf0d75cd25e13d3757027f7356e4466cb9

URL: https://github.com/llvm/llvm-project/commit/38f8a3cf0d75cd25e13d3757027f7356e4466cb9
DIFF: https://github.com/llvm/llvm-project/commit/38f8a3cf0d75cd25e13d3757027f7356e4466cb9.diff

LOG: [mlir][spirv] Improve folding of MemRef to SPIRV Lowering (#85433)

Investigate the lowering of MemRef Load/Store ops and implement
additional folding of the ops created during that lowering.

This aims to improve the readability of the generated SPIR-V code.

Part of the work tracked in llvm#70704.

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/GPUToSPIRV/load-store.mlir
    mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
    mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
    mlir/test/Conversion/SCFToSPIRV/for.mlir
    mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 0acb2142f3f68a..81b9f55cac80f7 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -50,11 +50,12 @@ static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
   assert(targetBits % sourceBits == 0);
   Type type = srcIdx.getType();
   IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
-  auto idx = builder.create<spirv::ConstantOp>(loc, type, idxAttr);
+  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
   IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
-  auto srcBitsValue = builder.create<spirv::ConstantOp>(loc, type, srcBitsAttr);
-  auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
-  return builder.create<spirv::IMulOp>(loc, type, m, srcBitsValue);
+  auto srcBitsValue =
+      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
+  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
+  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
 }
 
 /// Returns an adjusted spirv::AccessChainOp. Based on the
@@ -74,11 +75,11 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
   Value lastDim = op->getOperand(op.getNumOperands() - 1);
   Type type = lastDim.getType();
   IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
-  auto idx = builder.create<spirv::ConstantOp>(loc, type, attr);
+  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
   auto indices = llvm::to_vector<4>(op.getIndices());
   // There are two elements if this is a 1-D tensor.
   assert(indices.size() == 2);
-  indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
+  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
   Type t = typeConverter.convertType(op.getComponentPtr().getType());
   return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
 }
@@ -91,7 +92,8 @@ static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
     return srcBool;
   Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
   Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
-  return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
+  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
+                                               zero);
 }
 
 /// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
@@ -111,10 +113,10 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
           loc, builder.getIntegerType(targetBits), value);
     }
 
-    value = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
+    value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
   }
-  return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), value,
-                                                   offset);
+  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
+                                                         value, offset);
 }
 
 /// Returns true if the allocations of memref `type` generated from `allocOp`
@@ -165,7 +167,7 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
     return srcInt;
 
   auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
-  return builder.create<spirv::IEqualOp>(loc, srcInt, one);
+  return builder.createOrFold<spirv::IEqualOp>(loc, srcInt, one);
 }
 
 //===----------------------------------------------------------------------===//
@@ -597,13 +599,14 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   // ____XXXX________ -> ____________XXXX
   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
-  Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
+  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
       loc, spvLoadOp.getType(), spvLoadOp, offset);
 
   // Apply the mask to extract corresponding bits.
-  Value mask = rewriter.create<spirv::ConstantOp>(
+  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
-  result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
+  result =
+      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
 
   // Apply sign extension on the loading value unconditionally. The signedness
   // semantic is carried in the operator itself, we relies other pattern to
@@ -611,11 +614,11 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   IntegerAttr shiftValueAttr =
       rewriter.getIntegerAttr(dstType, dstBits - srcBits);
   Value shiftValue =
-      rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
-  result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
-                                                      shiftValue);
-  result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
-                                                          shiftValue);
+      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
+  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
+                                                            result, shiftValue);
+  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
+      loc, dstType, result, shiftValue);
 
   rewriter.replaceOp(loadOp, result);
 
@@ -744,11 +747,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
 
   // Create a mask to clear the destination. E.g., if it is the second i8 in
   // i32, 0xFFFF00FF is created.
-  Value mask = rewriter.create<spirv::ConstantOp>(
+  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
-  Value clearBitsMask =
-      rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
-  clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
+  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
+      loc, dstType, mask, offset);
+  clearBitsMask =
+      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
 
   Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
@@ -910,7 +914,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
 
     int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt();
     Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
-    return rewriter.create<spirv::ConstantOp>(loc, intType, attr);
+    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
   }();
 
   rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2b79c8022b8e85..4072608dc8f873 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -991,15 +991,16 @@ Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
   // broken down into progressive small steps so we can have intermediate steps
   // using other dialects. At the moment SPIR-V is the final sink.
 
-  Value linearizedIndex = builder.create<spirv::ConstantOp>(
+  Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
       loc, integerType, IntegerAttr::get(integerType, offset));
   for (const auto &index : llvm::enumerate(indices)) {
-    Value strideVal = builder.create<spirv::ConstantOp>(
+    Value strideVal = builder.createOrFold<spirv::ConstantOp>(
         loc, integerType,
         IntegerAttr::get(integerType, strides[index.index()]));
-    Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
+    Value update =
+        builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
     linearizedIndex =
-        builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
+        builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
   }
   return linearizedIndex;
 }

diff  --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
index fa12da8ef9d4ec..4339799ccd5eaf 100644
--- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
@@ -60,13 +60,9 @@ module attributes {
       // CHECK: %[[INDEX2:.*]] = spirv.IAdd %[[ARG4]], %[[LOCALINVOCATIONIDX]]
       %13 = arith.addi %arg4, %3 : index
       // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
-      // CHECK: %[[OFFSET1_0:.*]] = spirv.Constant 0 : i32
       // CHECK: %[[STRIDE1_1:.*]] = spirv.Constant 4 : i32
-      // CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[STRIDE1_1]], %[[INDEX1]] : i32
-      // CHECK: %[[OFFSET1_1:.*]] = spirv.IAdd %[[OFFSET1_0]], %[[UPDATE1_1]] : i32
-      // CHECK: %[[STRIDE1_2:.*]] = spirv.Constant 1 : i32
-      // CHECK: %[[UPDATE1_2:.*]] = spirv.IMul %[[STRIDE1_2]], %[[INDEX2]] : i32
-      // CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[OFFSET1_1]], %[[UPDATE1_2]] : i32
+      // CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[INDEX1]], %[[STRIDE1_1]] : i32
+      // CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[INDEX2]], %[[UPDATE1_1]] : i32
       // CHECK: %[[PTR1:.*]] = spirv.AccessChain %[[ARG0]]{{\[}}%[[ZERO]], %[[OFFSET1_2]]{{\]}}
       // CHECK-NEXT: %[[VAL1:.*]] = spirv.Load "StorageBuffer" %[[PTR1]]
       %14 = memref.load %arg0[%12, %13] : memref<12x4xf32, #spirv.storage_class<StorageBuffer>>

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
index 470c8531e2e0fb..52ed14e8cce233 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
@@ -12,16 +12,10 @@ module attributes {
 // CHECK-LABEL: @load_i1
 func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1 {
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  //     CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
-  //     CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
   //     CHECK: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]]
-  //     CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
-  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
-  //     CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
-  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
   //     CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
   //     CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   //     CHECK: %[[T4:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
@@ -37,32 +31,20 @@ func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1
 // INDEX64-LABEL: @load_i8
 func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  //     CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
-  //     CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
   //     CHECK: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]]
-  //     CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
-  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
-  //     CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
-  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
   //     CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
   //     CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   //     CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
   //     CHECK: builtin.unrealized_conversion_cast %[[SR]]
 
   //   INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
-  //   INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
-  //   INDEX64: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
-  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] : {{.+}}, i64, i64
+  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
   //   INDEX64: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]] : i32
-  //   INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
-  //   INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
-  //   INDEX64: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
-  //   INDEX64: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i64
   //   INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32
-  //   INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //   INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
   //   INDEX64: %[[T2:.+]] = spirv.Constant 24 : i32
   //   INDEX64: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   //   INDEX64: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
@@ -76,15 +58,12 @@ func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8
 func.func @load_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %index : index) -> i16 {
   //     CHECK: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  //     CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
-  //     CHECK: %[[UPDATE:.+]] = spirv.IMul %[[ONE]], %[[ARG1_CAST]] : i32
-  //     CHECK: %[[FLAT_IDX:.+]] = spirv.IAdd %[[ZERO]], %[[UPDATE]] : i32
   //     CHECK: %[[TWO:.+]] = spirv.Constant 2 : i32
-  //     CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[FLAT_IDX]], %[[TWO]] : i32
+  //     CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ARG1_CAST]], %[[TWO]] : i32
   //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
   //     CHECK: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]]
   //     CHECK: %[[SIXTEEN:.+]] = spirv.Constant 16 : i32
-  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[FLAT_IDX]], %[[TWO]] : i32
+  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ARG1_CAST]], %[[TWO]] : i32
   //     CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[SIXTEEN]] : i32
   //     CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spirv.Constant 65535 : i32
@@ -110,20 +89,12 @@ func.func @load_f32(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>) {
 func.func @store_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>, %value: i1) {
   //     CHECK: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  //     CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
-  //     CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
-  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
-  //     CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
-  //     CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
-  //     CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+  //     CHECK: %[[MASK:.+]] = spirv.Constant -256 : i32
   //     CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
   //     CHECK: %[[CASTED_ARG1:.+]] = spirv.Select %[[ARG1]], %[[ONE]], %[[ZERO]] : i1, i32
-  //     CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CASTED_ARG1]], %[[OFFSET]] : i32, i32
-  //     CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
+  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
   //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
-  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CASTED_ARG1]]
   memref.store %value, %arg0[] : memref<i1, #spirv.storage_class<StorageBuffer>>
   return
 }
@@ -136,36 +107,22 @@ func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %val
   //     CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
   //     CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  //     CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
-  //     CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
-  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
   //     CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
-  //     CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
-  //     CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+  //     CHECK: %[[MASK2:.+]] = spirv.Constant -256 : i32
   //     CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
-  //     CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
-  //     CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
-  //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
-  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
+  //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
+  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CLAMPED_VAL]]
 
   //   INDEX64-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
   //   INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //   INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
-  //   INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
-  //   INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
-  //   INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
-  //   INDEX64: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
   //   INDEX64: %[[MASK1:.+]] = spirv.Constant 255 : i32
-  //   INDEX64: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i64
-  //   INDEX64: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+  //   INDEX64: %[[MASK2:.+]] = spirv.Constant -256 : i32
   //   INDEX64: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
-  //   INDEX64: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i64
-  //   INDEX64: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
-  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] : {{.+}}, i64, i64
-  //   INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
-  //   INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
+  //   INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
+  //   INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[CLAMPED_VAL]]
   memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
   return
 }
@@ -177,19 +134,16 @@ func.func @store_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>,
   //     CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //     CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  //     CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
-  //     CHECK: %[[UPDATE:.+]] = spirv.IMul %[[ONE]], %[[ARG1_CAST]] : i32
-  //     CHECK: %[[FLAT_IDX:.+]] = spirv.IAdd %[[ZERO]], %[[UPDATE]] : i32
   //     CHECK: %[[TWO:.+]] = spirv.Constant 2 : i32
   //     CHECK: %[[SIXTEEN:.+]] = spirv.Constant 16 : i32
-  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[FLAT_IDX]], %[[TWO]] : i32
+  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ARG1_CAST]], %[[TWO]] : i32
   //     CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[SIXTEEN]] : i32
   //     CHECK: %[[MASK1:.+]] = spirv.Constant 65535 : i32
   //     CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
   //     CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG2_CAST]], %[[MASK1]] : i32
   //     CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
-  //     CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[FLAT_IDX]], %[[TWO]] : i32
+  //     CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ARG1_CAST]], %[[TWO]] : i32
   //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
   //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
   //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
@@ -222,15 +176,12 @@ module attributes {
 func.func @load_i4(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>, %i: index) -> i4 {
   // CHECK: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %{{.+}} : index to i32
   // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
-  // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[INDEX]] : i32
-  // CHECK: %[[OFFSET:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32
   // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
-  // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[OFFSET]], %[[EIGHT]] : i32
+  // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[INDEX]], %[[EIGHT]] : i32
   // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
   // CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32
   // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
-  // CHECK: %[[IDX:.+]] = spirv.UMod %[[OFFSET]], %[[EIGHT]] : i32
+  // CHECK: %[[IDX:.+]] = spirv.UMod %[[INDEX]], %[[EIGHT]] : i32
   // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[FOUR]] : i32
   // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   // CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
@@ -248,19 +199,16 @@ func.func @store_i4(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>, %v
   // CHECK: %[[VAL:.+]] = builtin.unrealized_conversion_cast %{{.+}} : i4 to i32
   // CHECK: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %{{.+}} : index to i32
   // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
-  // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[INDEX]] : i32
-  // CHECK: %[[OFFSET:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32
   // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
-  // CHECK: %[[FOUR:.+]] = spirv.Constant [[OFFSET]] : i32
-  // CHECK: %[[IDX:.+]] = spirv.UMod %[[OFFSET]], %[[EIGHT]] : i32
+  // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
+  // CHECK: %[[IDX:.+]] = spirv.UMod %[[INDEX]], %[[EIGHT]] : i32
   // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[FOUR]] : i32
   // CHECK: %[[MASK1:.+]] = spirv.Constant 15 : i32
   // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[BITS]] : i32, i32
   // CHECK: %[[MASK2:.+]] = spirv.Not %[[SL]] : i32
   // CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[VAL]], %[[MASK1]] : i32
   // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[BITS]] : i32, i32
-  // CHECK: %[[ACCESS_INDEX:.+]] = spirv.SDiv %[[OFFSET]], %[[EIGHT]] : i32
+  // CHECK: %[[ACCESS_INDEX:.+]] = spirv.SDiv %[[INDEX]], %[[EIGHT]] : i32
   // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ACCESS_INDEX]]]
   // CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK2]]
   // CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
@@ -283,16 +231,10 @@ module attributes {
 // INDEX64-LABEL: @load_i8
 func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  //     CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
-  //     CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]]
   //     CHECK: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]]
-  //     CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
-  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
-  //     CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
-  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
   //     CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
   //     CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   //     CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
@@ -300,16 +242,10 @@ func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8
   //     CHECK: return %[[CAST]] : i8
 
   //   INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
-  //   INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
-  //   INDEX64: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
-  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] : {{.+}}, i64, i64
+  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
   //   INDEX64: %[[LOAD:.+]] = spirv.Load  "StorageBuffer" %[[PTR]] : i32
-  //   INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
-  //   INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
-  //   INDEX64: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
-  //   INDEX64: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i64
   //   INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32
-  //   INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //   INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[LOAD]], %[[MASK]] : i32
   //   INDEX64: %[[T2:.+]] = spirv.Constant 24 : i32
   //   INDEX64: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   //   INDEX64: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
@@ -326,37 +262,19 @@ func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8
 func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %value: i8) {
   //     CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //     CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  //     CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
-  //     CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
-  //     CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
-  //     CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
-  //     CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
-  //     CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+  //     CHECK: %[[MASK1:.+]] = spirv.Constant -256 : i32
   //     CHECK: %[[ARG1_CAST:.+]] = spirv.UConvert %[[ARG1]] : i8 to i32
-  //     CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
-  //     CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
-  //     CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
-  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
-  //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
-  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+  //     CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]]
+  //     CHECK: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK1]]
+  //     CHECK: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[ARG1_CAST]]
 
   //   INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //   INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
-  //   INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
-  //   INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
-  //   INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
-  //   INDEX64: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
-  //   INDEX64: %[[MASK1:.+]] = spirv.Constant 255 : i32
-  //   INDEX64: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i64
-  //   INDEX64: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+  //   INDEX64: %[[MASK1:.+]] = spirv.Constant -256 : i32
   //   INDEX64: %[[ARG1_CAST:.+]] = spirv.UConvert %[[ARG1]] : i8 to i32
-  //   INDEX64: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
-  //   INDEX64: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i64
-  //   INDEX64: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
-  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] : {{.+}}, i64, i64
-  //   INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK]]
-  //   INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[STORE_VAL]]
+  //   INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ZERO]]] : {{.+}}, i64, i64
+  //   INDEX64: spirv.AtomicAnd <Device> <AcquireRelease> %[[PTR]], %[[MASK1]]
+  //   INDEX64: spirv.AtomicOr <Device> <AcquireRelease> %[[PTR]], %[[ARG1_CAST]]
   memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
   return
 }

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index feb6d4e924015f..10c03a270005f1 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -70,11 +70,8 @@ func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spirv.stora
 func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i : index) -> i1 {
   // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, StorageBuffer>
   // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
-  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
-  // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32
-  // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32
-  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[ADD]]]
+  // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]]
   // CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] : i8
   // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
   // CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8
@@ -90,15 +87,10 @@ func.func @store_i1(%dst: memref<4xi1, #spirv.storage_class<StorageBuffer>>, %i:
   %true = arith.constant true
   // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x i8, stride=1> [0])>, StorageBuffer>
   // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
-  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
-  // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32
-  // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32
-  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[ZERO]], %[[ADD]]]
-  // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
+  // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[ZERO]], %[[IDX_CAST]]]
   // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
-  // CHECK: %[[RES:.+]] = spirv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8
-  // CHECK: spirv.Store "StorageBuffer" %[[ADDR]], %[[RES]] : i8
+  // CHECK: spirv.Store "StorageBuffer" %[[ADDR]], %[[ONE_I8]] : i8
   memref.store %true, %dst[%i]: memref<4xi1, #spirv.storage_class<StorageBuffer>>
   return
 }
@@ -234,11 +226,7 @@ func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spirv.stora
 func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i : index) -> i1 {
   // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<4 x i8>, CrossWorkgroup>
   // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
-  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
-  // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32
-  // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32
-  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ADD]]]
+  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[IDX_CAST]]]
   // CHECK: %[[VAL:.+]] = spirv.Load "CrossWorkgroup" %[[ADDR]] : i8
   // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
   // CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8
@@ -254,15 +242,9 @@ func.func @store_i1(%dst: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i
   %true = arith.constant true
   // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<4 x i8>, CrossWorkgroup>
   // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
-  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
-  // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32
-  // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32
-  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[ADD]]]
-  // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
+  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[IDX_CAST]]]
   // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
-  // CHECK: %[[RES:.+]] = spirv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8
-  // CHECK: spirv.Store "CrossWorkgroup" %[[ADDR]], %[[RES]] : i8
+  // CHECK: spirv.Store "CrossWorkgroup" %[[ADDR]], %[[ONE_I8]] : i8
   memref.store %true, %dst[%i]: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>
   return
 }

diff  --git a/mlir/test/Conversion/SCFToSPIRV/for.mlir b/mlir/test/Conversion/SCFToSPIRV/for.mlir
index 02558463b8662d..81661ec7a3a060 100644
--- a/mlir/test/Conversion/SCFToSPIRV/for.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/for.mlir
@@ -19,17 +19,9 @@ func.func @loop_kernel(%arg2 : memref<10xf32, #spirv.storage_class<StorageBuffer
   // CHECK:        spirv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
   // CHECK:      ^[[BODY]]:
   // CHECK:        %[[ZERO1:.*]] = spirv.Constant 0 : i32
-  // CHECK:        %[[OFFSET1:.*]] = spirv.Constant 0 : i32
-  // CHECK:        %[[STRIDE1:.*]] = spirv.Constant 1 : i32
-  // CHECK:        %[[UPDATE1:.*]] = spirv.IMul %[[STRIDE1]], %[[INDVAR]] : i32
-  // CHECK:        %[[INDEX1:.*]] = spirv.IAdd %[[OFFSET1]], %[[UPDATE1]] : i32
-  // CHECK:        spirv.AccessChain {{%.*}}{{\[}}%[[ZERO1]], %[[INDEX1]]{{\]}}
+  // CHECK:        spirv.AccessChain {{%.*}}{{\[}}%[[ZERO1]], %[[INDVAR]]{{\]}}
   // CHECK:        %[[ZERO2:.*]] = spirv.Constant 0 : i32
-  // CHECK:        %[[OFFSET2:.*]] = spirv.Constant 0 : i32
-  // CHECK:        %[[STRIDE2:.*]] = spirv.Constant 1 : i32
-  // CHECK:        %[[UPDATE2:.*]] = spirv.IMul %[[STRIDE2]], %[[INDVAR]] : i32
-  // CHECK:        %[[INDEX2:.*]] = spirv.IAdd %[[OFFSET2]], %[[UPDATE2]] : i32
-  // CHECK:        spirv.AccessChain {{%.*}}[%[[ZERO2]], %[[INDEX2]]]
+  // CHECK:        spirv.AccessChain {{%.*}}[%[[ZERO2]], %[[INDVAR]]]
   // CHECK:        %[[INCREMENT:.*]] = spirv.IAdd %[[INDVAR]], %[[STEP]] : i32
   // CHECK:        spirv.Branch ^[[HEADER]](%[[INCREMENT]] : i32)
   // CHECK:      ^[[MERGE]]

diff  --git a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
index 19de613bf5b073..32d0fbea65b164 100644
--- a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
@@ -14,14 +14,12 @@ func.func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
   // CHECK: spirv.Store "Function" %[[VAR]], %[[CST]] : !spirv.array<12 x i32>
   // CHECK: %[[C0:.+]] = spirv.Constant 0 : i32
   // CHECK: %[[C6:.+]] = spirv.Constant 6 : i32
-  // CHECK: %[[MUL0:.+]] = spirv.IMul %[[C6]], %[[A]] : i32
-  // CHECK: %[[ADD0:.+]] = spirv.IAdd %[[C0]], %[[MUL0]] : i32
+  // CHECK: %[[MUL0:.+]] = spirv.IMul %[[A]], %[[C6]] : i32
   // CHECK: %[[C3:.+]] = spirv.Constant 3 : i32
-  // CHECK: %[[MUL1:.+]] = spirv.IMul %[[C3]], %[[B]] : i32
-  // CHECK: %[[ADD1:.+]] = spirv.IAdd %[[ADD0]], %[[MUL1]] : i32
+  // CHECK: %[[MUL1:.+]] = spirv.IMul %[[B]], %[[C3]] : i32
+  // CHECK: %[[ADD1:.+]] = spirv.IAdd %[[MUL1]], %[[MUL0]] : i32
   // CHECK: %[[C1:.+]] = spirv.Constant 1 : i32
-  // CHECK: %[[MUL2:.+]] = spirv.IMul %[[C1]], %[[C]] : i32
-  // CHECK: %[[ADD2:.+]] = spirv.IAdd %[[ADD1]], %[[MUL2]] : i32
+  // CHECK: %[[ADD2:.+]] = spirv.IAdd %[[C]], %[[ADD1]] : i32
   // CHECK: %[[AC:.+]] = spirv.AccessChain %[[VAR]][%[[ADD2]]]
   // CHECK: %[[VAL:.+]] = spirv.Load "Function" %[[AC]] : i32
   %extract = tensor.extract %cst[%a, %b, %c] : tensor<2x2x3xi32>

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index c9984091d5acc6..cddc4ee385357d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -720,9 +720,7 @@ module attributes {
 //       CHECK:   %[[CST1:.+]] = spirv.Constant 0 : i32
 //       CHECK:   %[[CST2:.+]] = spirv.Constant 0 : i32
 //       CHECK:   %[[CST3:.+]] = spirv.Constant 1 : i32
-//       CHECK:   %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
-//       CHECK:   %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
-//       CHECK:   %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:   %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S1]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
 //       CHECK:   %[[S5:.+]] = spirv.Bitcast %[[S4]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
 //       CHECK:   %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S5]] : vector<4xf32>
 //       CHECK:   return %[[R0]] : vector<4xf32>
@@ -743,11 +741,9 @@ func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>
 //       CHECK:   %[[CST0_1:.+]] = spirv.Constant 0 : i32
 //       CHECK:   %[[CST0_2:.+]] = spirv.Constant 0 : i32
 //       CHECK:   %[[CST4:.+]] = spirv.Constant 4 : i32
-//       CHECK:   %[[S3:.+]] = spirv.IMul %[[CST4]], %[[S1]] : i32
-//       CHECK:   %[[S4:.+]] = spirv.IAdd %[[CST0_2]], %[[S3]] : i32
+//       CHECK:   %[[S3:.+]] = spirv.IMul %[[S1]], %[[CST4]] : i32
 //       CHECK:   %[[CST1:.+]] = spirv.Constant 1 : i32
-//       CHECK:   %[[S5:.+]] = spirv.IMul %[[CST1]], %[[S2]] : i32
-//       CHECK:   %[[S6:.+]] = spirv.IAdd %[[S4]], %[[S5]] : i32
+//       CHECK:   %[[S6:.+]] = spirv.IAdd  %[[S2]], %[[S3]] : i32
 //       CHECK:   %[[S7:.+]] = spirv.AccessChain %[[S0]][%[[CST0_1]], %[[S6]]] : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
 //       CHECK:   %[[S8:.+]] = spirv.Bitcast %[[S7]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
 //       CHECK:   %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S8]] : vector<4xf32>
@@ -768,9 +764,7 @@ func.func @vector_load_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageBu
 //       CHECK:   %[[CST1:.+]] = spirv.Constant 0 : i32
 //       CHECK:   %[[CST2:.+]] = spirv.Constant 0 : i32
 //       CHECK:   %[[CST3:.+]] = spirv.Constant 1 : i32
-//       CHECK:   %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
-//       CHECK:   %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
-//       CHECK:   %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:   %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S1]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
 //       CHECK:   %[[S5:.+]] = spirv.Bitcast %[[S4]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
 //       CHECK:   spirv.Store "StorageBuffer" %[[S5]], %[[ARG1]] : vector<4xf32>
 func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
@@ -790,11 +784,9 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
 //       CHECK:   %[[CST0_1:.+]] = spirv.Constant 0 : i32
 //       CHECK:   %[[CST0_2:.+]] = spirv.Constant 0 : i32
 //       CHECK:   %[[CST4:.+]] = spirv.Constant 4 : i32
-//       CHECK:   %[[S3:.+]] = spirv.IMul %[[CST4]], %[[S1]] : i32
-//       CHECK:   %[[S4:.+]] = spirv.IAdd %[[CST0_2]], %[[S3]] : i32
+//       CHECK:   %[[S3:.+]] = spirv.IMul %[[S1]], %[[CST4]] : i32
 //       CHECK:   %[[CST1:.+]] = spirv.Constant 1 : i32
-//       CHECK:   %[[S5:.+]] = spirv.IMul %[[CST1]], %[[S2]] : i32
-//       CHECK:   %[[S6:.+]] = spirv.IAdd %[[S4]], %[[S5]] : i32
+//       CHECK:   %[[S6:.+]] = spirv.IAdd %[[S2]], %[[S3]] : i32
 //       CHECK:   %[[S7:.+]] = spirv.AccessChain %[[S0]][%[[CST0_1]], %[[S6]]] : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
 //       CHECK:   %[[S8:.+]] = spirv.Bitcast %[[S7]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
 //       CHECK:   spirv.Store "StorageBuffer" %[[S8]], %[[ARG1]] : vector<4xf32>


        


More information about the Mlir-commits mailing list