implementing neural;

#include "common-def.slang"

// NOTE(review): the generic parameter lists (<T>, <NBytes>, <Op>, ...) appear to
// have been stripped from this view of the file; the token stream below is kept
// as-is. Confirm signatures against the original before relying on them.

// Abstraction over any indexable storage (structured buffer, raw pointer,
// in-register array) that the packed uint4 load/store helpers below target.
internal interface IArrayAccessor
{
    // Default implementation: storage types that cannot be updated atomically
    // (e.g. local arrays) fail at compile time if ATOMIC_ADD is requested.
    internal void atomicAdd(int index, T value)
    {
        static_assert(false, "atomicAdd is not supported for IArrayAccessor");
    }

    __subscript(int index)->T { get; set; }
}

// Device-global structured buffers support hardware atomics.
internal extension RWStructuredBuffer.Handle : IArrayAccessor
{
    [ForceInline]
    override internal void atomicAdd(int index, T value)
    {
        __atomic_reduce_add(this[index], value);
    }
}

// Raw pointers: forwarding subscript plus an atomic add on the pointee.
internal extension Ptr : IArrayAccessor
{
    internal __subscript(int index) -> T
    {
        [ForceInline]
        get { return this[index]; }
        [ForceInline]
        set { this[index] = newValue; }
    }

    [ForceInline]
    override internal void atomicAdd(int index, T value)
    {
        __atomic_reduce_add(this[index], value);
    }
}

// Local arrays: subscript only; atomicAdd keeps the static_assert default above.
internal extension Array : IArrayAccessor
{
    internal __subscript(int index) -> T
    {
        [ForceInline]
        get { return this[index]; }
        [ForceInline]
        set { this[index] = newValue; }
    }
}

// How a packed uint4 access interacts with the backing storage.
VISIBILITY_LEVEL enum AccessOp : uint32_t
{
    READ,       // unpack buffer elements into the uint4
    WRITE,      // overwrite buffer elements from the uint4
    ACCUMULATE, // non-atomic read-modify-write add
    ATOMIC_ADD, // atomic add; requires a storage type overriding atomicAdd
}

#define COMMON_TYPE_CONSTRAINTS             \
    where T : __BuiltinFloatingPointType    \
    where U : __BuiltinFloatingPointType    \
    where BufferType : IArrayAccessor

// Reads buffer[bufferIdx], converts it to T, and ORs its raw bits into
// `result` at bit-slot `elementIdx` (each slot is BitsShiftPerRead bits wide).
// NBytes is sizeof(T); the switch resolves at compile time.
[ForceInline]
internal static void readOneElement(BufferType buffer, int bufferIdx, int elementIdx, inout uint result)
    COMMON_TYPE_CONSTRAINTS
{
    const uint shift = BitsShiftPerRead * elementIdx;
    T convertedValue;
    convertedValue = __realCast(buffer[bufferIdx]);
    switch (NBytes)
    {
    case 1:
        result |= uint(bit_cast(convertedValue)) << shift;
        break;
    case 2:
        result |= uint(bit_cast(convertedValue)) << shift;
        break;
    case 4:
        result |= uint(bit_cast(convertedValue)) << shift;
        break;
    default:
        static_assert(false, "Unsupported data type T");
    }
}

// Extracts bit-slot `elementIdx` from `value`, reinterprets it as T, converts
// to the buffer element type U, then applies Op (write / accumulate / atomic
// add) at buffer[bufferIdx]. Both switches resolve at compile time.
[ForceInline]
internal static void writeOneElement(inout BufferType buffer, int bufferIdx, int elementIdx, uint value)
    COMMON_TYPE_CONSTRAINTS
{
    const uint shift = BitsShiftPerWrite * elementIdx;
    U convertedValue;
    switch (NBytes)
    {
    case 1:
        convertedValue = __realCast(bit_cast((uint8_t)(value >> shift)));
        break;
    case 2:
        convertedValue = __realCast(bit_cast((uint16_t)(value >> shift)));
        break;
    case 4:
        convertedValue = __realCast(bit_cast((uint)(value >> shift)));
        break;
    default:
        static_assert(false, "Unsupported data type T");
    }
    switch (Op)
    {
    case AccessOp.WRITE:
        buffer[bufferIdx] = convertedValue;
        break;
    case AccessOp.ACCUMULATE:
        buffer[bufferIdx] = buffer[bufferIdx] + convertedValue;
        break;
    case AccessOp.ATOMIC_ADD:
        buffer.atomicAdd(bufferIdx, convertedValue);
        break;
    default:
        static_assert(false, "Unsupported access operation");
    }
}

// Transfers a full uint4 (4 * WritePerElement packed elements of type T)
// to/from `buffer` starting at `startIndex`, with no bounds clipping.
// Fully unrolled and branchless with respect to the data.
[ForceInline]
internal static void accessUint4Aligned(inout BufferType buffer, int startIndex, inout uint4 value)
    COMMON_TYPE_CONSTRAINTS
{
    const int nBytes = sizeof(T);
    const int WritePerElement = 4 / nBytes;        // elements packed per uint
    const int BitsShiftPerWrite = 32 / WritePerElement;

    if (Op == AccessOp.READ)
        value = uint4(0, 0, 0, 0);

    [ForceUnroll]
    for (int i = 0; i < 4; i++)
    {
        [ForceUnroll]
        for (int j = 0; j < WritePerElement; j++)
        {
            int index = startIndex + i * WritePerElement + j;
            switch (Op)
            {
            case AccessOp.READ:
                readOneElement(buffer, index, j, value[i]);
                break;
            case AccessOp.WRITE:
            case AccessOp.ACCUMULATE:
            case AccessOp.ATOMIC_ADD:
                writeOneElement(buffer, index, j, value[i]);
                break;
            default:
                static_assert(false, "Unsupported access operation");
            }
        }
    }
}

// Transfers up to 4 * ReadPerElement packed elements to/from `buffer` at
// `startIndex`, clipping the tail so the access never crosses the end of the
// current Stride-sized row (measured from `baseIndex`).
[ForceInline]
internal void accessUint4(BufferType buffer, int baseIndex, int startIndex, inout uint4 value)
    COMMON_TYPE_CONSTRAINTS
{
    if (IsAligned)
    {
        // Call the aligned version of readUint4 which is branchless.
        accessUint4Aligned(buffer, startIndex, value);
        return;
    }

    if (Op == AccessOp.READ)
        value = uint4(0, 0, 0, 0);

    // T is the type of the source (read) or destination (write) data. We always
    // pack elements into a uint4, so T determines how many elements fit.
    // If U differs from T, each element is converted U -> T on read, or
    // T -> U on write; U never affects how many elements are transferred.
    const int nBytes = sizeof(T);
    const int ReadPerElement = 4 / nBytes;
    const int BitsShiftPerRead = 32 / ReadPerElement;

    const int x = (startIndex - baseIndex) % Stride;
    // End offset of this access within the row: [x, x + 4*ReadPerElement - 1].
    const int endAddress = (x + 4 * ReadPerElement - 1);
    // Number of trailing elements that fall past the end of the row, i.e.
    // 0 when endAddress < Stride, otherwise endAddress - Stride + 1.
    const int paddingCount = max(0, endAddress - Stride + 1);
    const int elementsToRead = (4 * ReadPerElement) - paddingCount;

    [ForceUnroll]
    for (int i = 0; i < 4; i++)
    {
        [ForceUnroll]
        for (int j = 0; j < ReadPerElement; j++)
        {
            // Linear element offset of slot (i, j) within this uint4 access.
            // BUGFIX(review): the original accumulated `offset += j` across
            // iterations, which miscomputes indices once ReadPerElement > 2
            // (j=2 produced base+3 instead of base+2). Identical behavior for
            // ReadPerElement of 1 or 2.
            const int offset = i * ReadPerElement + j;
            // Stop once we reach the clipped element count; everything past it
            // is row padding and must be neither read nor written.
            if (offset >= elementsToRead)
            {
                return;
            }
            int index = (startIndex + offset);
            switch (Op)
            {
            case AccessOp.READ:
                readOneElement(buffer, index, j, value[i]);
                break;
            case AccessOp.WRITE:
            case AccessOp.ACCUMULATE:
            case AccessOp.ATOMIC_ADD:
                writeOneElement(buffer, index, j, value[i]);
                break;
            default:
                static_assert(false, "Unsupported access operation");
            }
        }
    }
}