SIMT
The programming model of a GPU uses what has been coined "single instruction multiple thread." The idea is that the programmer writes their program from the perspective of a single thread, using ordinary scalar variables. So, for example, a programmer might write something like:
int x = threadID;
int y = 6;
int z = x + y;
Straightforward, right? Then, they ask the system to run this program a million times, in parallel, with different threadIDs.
The system *could* simply schedule a million threads to do this, but GPUs do better than this. Instead, the compiler will transparently rewrite the program to use vector registers and instructions in order to run multiple "threads" at the same time. So, imagine you have a vector register, where each item in the vector represents a scalar from a particular "thread." In the above program, x corresponds to a vector of [0, 1, 2, 3, etc.] and y corresponds to a vector of [6, 6, 6, 6, etc.]. Then, the operation x + y is simply a single vector add operation of both vectors. This way, performance can be dramatically improved, because one vector operation is usually significantly faster than performing each scalar operation one-by-one.
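To make the transformation concrete, here is a minimal sketch in Python that emulates a vector register as a list with one entry per "thread." The names run_scalar, run_vectorized, and NUM_LANES are made up for illustration, and the list comprehension stands in for what would be a single vector add on real hardware.

```python
NUM_LANES = 8  # assumed number of "threads" packed into one vector register


def run_scalar(thread_id):
    """The program as the author wrote it, from one thread's perspective."""
    x = thread_id
    y = 6
    return x + y


def run_vectorized(num_lanes):
    """The same program after the SIMT rewrite: every variable is a vector."""
    x = list(range(num_lanes))          # [0, 1, 2, 3, ...]
    y = [6] * num_lanes                 # broadcast of the literal 6
    z = [a + b for a, b in zip(x, y)]   # stands in for one vector add
    return z


# Both versions must agree, lane by lane.
assert run_vectorized(NUM_LANES) == [run_scalar(t) for t in range(NUM_LANES)]
```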
(This is in contrast to SIMD, or "single instruction multiple data," where the programmer explicitly uses vector types and operations in their program. The SIMD approach is suited for when you have a single program that has to process a lot of data, whereas SIMT is suited for when you have many programs and each one operates on its own data.)
SIMT gets complicated, though, when you have control flow. Imagine the program did something like:
if (threadID < 4) {
    doSomethingObservable();
}
Here, the system has to behave as-if threads 0-3 executed the "then" block, but also behave as-if threads 4-n didn't execute it. And, of course, threads 0-3 want to take advantage of vector operations - you don't want to pessimize and run each thread serially. So, what do you do?
Well, the way that GPUs handle this is by using predicated instructions. There is a bitmask which indicates which "threads" are alive: within the above "then" block, that bitmask will have value 0xF. Then, all the vector instructions use this bitmask to determine which elements of the vector they should actually operate on. So, if the bitmask is 0xF, and you execute a vector add operation, the vector add operation is only going to perform the add on the 0th-3rd items in the vector. (Or, at least, it will behave "as-if" it only performed the operation on those items, from an observability perspective.) So, the way that control flow like this works is: all threads actually execute the "then" block, but all the operations in the block are predicated on a bitmask which specifies that only certain items in the vector operations should actually be performed. The "if" statement itself just modifies the bitmask.
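Here is a small Python emulation of that idea, assuming 8 lanes. The helpers compare_less_than and masked_add are hypothetical stand-ins for a predicated vector compare and a predicated vector add; they are not real instruction names.

```python
def compare_less_than(a, b, mask):
    """Return a bitmask with bit i set where lane i is alive and a[i] < b[i]."""
    result = 0
    for i in range(len(a)):
        if (mask >> i) & 1 and a[i] < b[i]:
            result |= 1 << i
    return result


def masked_add(dst, a, b, mask):
    """Predicated add: lane i is written only if bit i of mask is set."""
    return [a[i] + b[i] if (mask >> i) & 1 else dst[i] for i in range(len(dst))]


thread_ids = list(range(8))
all_alive = 0xFF

# The "if (threadID < 4)" just computes a new bitmask...
then_mask = compare_less_than(thread_ids, [4] * 8, all_alive)

# ...and everything inside the "then" block is predicated on it.
counter = [0] * 8
counter = masked_add(counter, counter, [1] * 8, then_mask)
```

After this runs, then_mask is 0xF and only the first four entries of counter have been incremented.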
The Project
AVX-512 is an optional instruction set on some (fairly rare) x86_64 machines. The exciting thing about AVX-512 is that it adds support for this predication bitmask thing. It has a bunch of vector registers (512 bits wide, named zmm0 - zmm31) and it also adds a set of predication bitmask registers (k0 - k7). The instructions that act upon the vector registers can be predicated on the value of one of those predication registers, to achieve the effect of SIMT.
It turns out I actually have a machine lying around in my home which supports AVX-512, so I thought I'd give it a go, and actually implement a compiler that compiles a toy language, but performs the SIMT transformation to use the vector operations and predication registers. The purpose of this exercise isn't really to achieve incredible performance - there are lots of sophisticated compiler optimizations which I am not really interested in implementing - but instead the purpose is really just as a learning exercise. Hopefully, by implementing this transformation myself for a toy language, I can learn more about the kinds of things that real GPU compilers do.
The toy language is one I invented myself - it's very similar to C, with some syntax that's slightly easier to parse. Programs look like this:
function main(index: uint64): uint64 {
    variable accumulator: uint64 = 0;
    variable accumulatorPointer: pointer<uint64> = &accumulator;
    for (variable i: uint64 = 0; i < index; i = i + 1) {
        accumulator = *accumulatorPointer + i;
    }
    return accumulator;
}
It's pretty straightforward. It doesn't have things like ++ or +=. It also doesn't have floating-point numbers (which is fine, because AVX-512 supports vector integer operations). It has pointers, for loops, continue & break statements, early returns... the standard stuff.
Tour
Let's take a tour, and examine how each piece of a C-like language gets turned into AVX-512 SIMT. I implemented this so it can run real programs, and tested it somewhat-rigorously - enough to be fairly convinced that it's generally correct.
Variables and Simple Math
The most straightforward part of this system is variables and literal math. Consider:
variable accumulator: uint64;
This is a variable declaration. Each thread may store different values into the variable, so its storage needs to be a vector. No problem, right?
What about if the variable's type is a complex type? Consider:
struct Foo {
    x: uint64;
    y: uint64;
}
variable bar: Foo;
Here, we need to maintain the invariant that Foo.x has the same memory layout as any other uint64. This means that, rather than alternating x,y,x,y,x,y in memory, there instead has to be a vector for all the threads' x values, followed by another vector for all the threads' y values. This works recursively: if a struct has other structs inside it, the compiler will go through all the leaf-types in the tree, turn each leaf type into a vector, and then lay them out in memory end-to-end.
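The recursive layout can be sketched like this in Python. This is only an illustration: the type description format, the function names, and the assumption that vectors are packed end-to-end with no padding between them are all mine, not necessarily what the real compiler does.

```python
NUM_LANES = 8                            # assumed number of "threads"
LEAF_SIZE = {"uint32": 4, "uint64": 8}   # bytes per scalar leaf


def leaf_types(type_desc):
    """Flatten a (possibly nested) struct description into its leaf types, in order."""
    if isinstance(type_desc, str):
        return [type_desc]
    leaves = []
    for _field_name, field_type in type_desc:
        leaves.extend(leaf_types(field_type))
    return leaves


def vector_layout(type_desc):
    """Return ([(leaf type, byte offset), ...], total size): one vector per leaf."""
    layout, offset = [], 0
    for leaf in leaf_types(type_desc):
        layout.append((leaf, offset))
        offset += LEAF_SIZE[leaf] * NUM_LANES  # a whole vector of this leaf
    return layout, offset


foo = [("x", "uint64"), ("y", "uint64")]
nested = [("inner", foo), ("tag", "uint32")]

foo_layout, foo_size = vector_layout(foo)
nested_layout, nested_size = vector_layout(nested)
```

For Foo, this puts all eight x values at offset 0 and all eight y values at offset 64, rather than interleaving them.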
Simple math is even more straightforward. Literal numbers have the same value no matter which thread you're running, so they just turn into broadcast instructions. The program says "3" and the instruction that gets executed is "broadcast 3 to every item in a vector". Easy peasy.
L-values and R-values
In a C-like language, every value is categorized as either an "l-value" or an "r-value". An l-value is defined as having a location in memory, and r-values don't have a location in memory. The value produced by the expression "2 + 3" is an r-value, but the value produced by the expression "*foo()" is an l-value, because you dereferenced the pointer, so the thing the pointer points to is the location in memory of the resulting value. L-values can be assigned to; r-values cannot be assigned to. So, you can say things like "foo = 3 + 4;" (because "foo" refers to a variable, which has a memory location) but you can't say "3 + 4 = foo;". That's why it's called "l-value" and "r-value" - l-values are legal on the left side of an assignment.
At runtime, every expression has to produce some value, which is consumed by its parent in the AST. E.g., in "3 * 4 + 5", the "3 * 4" has to produce a "12" which the "+" will consume. The simplest way to handle l-values is to make them produce a pointer. This is so expressions like "&foo" work - the "foo" is an l-value and produces a pointer that points to the variable's storage, and the & operator receives this pointer and produces that same pointer (unmodified!) as an r-value. The same thing happens in reverse for the unary * ("dereference") operator: it accepts an r-value of pointer type, and produces an l-value - which is just the pointer it just received. This is how expressions like "*&*&*&*&*&foo = 7;" work (which is totally legal and valid C!): the "foo" produces a pointer, which the & operator accepts and passes through untouched to the *, which takes it and passes it through untouched to the next &, and so on all the way to the final *, which produces the same pointer as an l-value that points to the storage of foo.
The assignment operator knows that the thing on its left side must be an lvalue and therefore will always produce a pointer, so that's the storage that the assignment stores into. The right side can either be an l-value or an r-value; if it's an l-value, the assignment operator has to read from the thing it points to; otherwise, it's an r-value, and the assignment operator reads the value itself. This is generalized to every operation: it's legal to say "foo + 3", so the + operator needs to determine which of its parameters are l-values, and will thus produce pointers instead of values, and it will need to react accordingly to read from the storage the pointers point to.
All this stuff means that, even for simple programs where the author didn't even spell the name "pointer" anywhere in the program, or even use the * or & operators anywhere in the program, there will still be pointers internally used just by virtue of the fact that there will be l-values used in the program. So, dealing with pointers is a core part of the language. They appear everywhere, whether the program author wants them to or not.
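A tiny tree-walking evaluator shows how this plays out. This is a scalar sketch (no SIMT yet) with a made-up tuple-based AST and a dictionary standing in for memory; none of these names come from the real compiler.

```python
def load(result, memory):
    """Turn any expression result into a plain value, reading memory for l-values."""
    kind, value = result
    return memory[value] if kind == "lvalue" else value


def evaluate(node, env, memory):
    """Return ("lvalue", address) or ("rvalue", value) for an expression node."""
    op = node[0]
    if op == "lit":
        return ("rvalue", node[1])
    if op == "var":
        return ("lvalue", env[node[1]])          # a variable produces its address
    if op == "addr":
        kind, address = evaluate(node[1], env, memory)
        assert kind == "lvalue", "& needs an l-value"
        return ("rvalue", address)               # same pointer, now an r-value
    if op == "deref":
        pointer = load(evaluate(node[1], env, memory), memory)
        return ("lvalue", pointer)               # same pointer, now an l-value
    if op == "add":
        left = load(evaluate(node[1], env, memory), memory)
        right = load(evaluate(node[2], env, memory), memory)
        return ("rvalue", left + right)
    if op == "assign":
        kind, address = evaluate(node[1], env, memory)
        assert kind == "lvalue", "left side of = must be an l-value"
        memory[address] = load(evaluate(node[2], env, memory), memory)
        return ("rvalue", memory[address])
    raise ValueError(f"unknown node {op!r}")


env = {"foo": 100}      # foo lives at (made-up) address 100
memory = {100: 0}

# *&*&foo = 7;
target = ("deref", ("addr", ("deref", ("addr", ("var", "foo")))))
evaluate(("assign", target, ("lit", 7)), env, memory)

# foo + 3
total = evaluate(("add", ("var", "foo"), ("lit", 3)), env, memory)
```

Note that "foo + 3" never mentions a pointer, but evaluating it still goes through an address and a load.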
Pointers
If we now think about what this means for SIMT, l-values produce pointers, but each thread has to get its own distinct pointer! That's because of programs like this:
variable x: pointer<uint64>;
if (...) {
    x = &something;
} else {
    x = &somethingElse;
}
*x = 4;
That *x expression is an l-value. It's not special - it's just like any other l-value. The assignment operator needs to handle the fact that, in SIMT, the lvalue that *x produces is a vector of pointers, where each pointer can potentially be distinct. Therefore, that assignment operator doesn't actually perform a single vector store; instead, it performs a "scatter" operation. There's a vector of pointers, and there's a vector of values to store to those pointers; the assignment operator might end up spraying those values all around memory. In AVX-512, there's an instruction that does this scatter operation.
(Aside: That scatter operation in AVX-512 uses a predication mask register (of course), but the instruction has a side-effect of clearing that register. That kind of sucks from the programmer's point of view - the program has to save and restore the value of the register just because of a quirk of this instruction. But then, thinking about it more, I realized that the memory operation might cause a page fault, which has to be handled by the operating system. As each memory access completes, the corresponding bit in the predication register gets set to false, so if a fault interrupts the instruction partway through, the register records which elements are still outstanding. Once the operating system has handled the fault, the instruction can be restarted, and it will only redo the elements whose bits are still high. So it makes sense why the operation will clear the register, but it is annoying to deal with from the programmer's perspective.)
And, of course, the program can also say "foo = *x;" which means that there also has to be a gather operation. Sure. Something like "*x = *y;" will end up doing both a gather and a scatter.
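To illustrate, here is a Python emulation of a predicated gather and scatter, with a plain list standing in for memory and word indices standing in for pointers. The function names and the 4-lane width are assumptions for the sake of a short example.

```python
def gather(memory, pointers, mask, fallback):
    """Load memory[pointers[i]] for each live lane; dead lanes keep fallback[i]."""
    return [memory[pointers[i]] if (mask >> i) & 1 else fallback[i]
            for i in range(len(pointers))]


def scatter(memory, pointers, values, mask):
    """Store values[i] to memory[pointers[i]] for each live lane only."""
    for i in range(len(pointers)):
        if (mask >> i) & 1:
            memory[pointers[i]] = values[i]


memory = list(range(100, 116))   # 16 words holding 100..115
x = [8, 9, 10, 11]               # per-lane destination pointers
y = [0, 1, 2, 3]                 # per-lane source pointers
live = 0b0101                    # only lanes 0 and 2 are alive

# "*x = *y;" becomes a gather followed by a scatter.
loaded = gather(memory, y, live, fallback=[0] * 4)
scatter(memory, x, loaded, live)
```

Only the words that lanes 0 and 2 point at get overwritten; the words belonging to the dead lanes are left alone.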
Copying
Consider a program like:
struct Foo {
    x: uint64;
    y: uint64;
}
someVariableOfFooType = aFunctionThatReturnsAFoo();
That initializer needs to set both fields inside the Foo. Naively, a compiler might be tempted to use a memcpy() to copy the contents - after all, the contents could be arbitrarily complex, with nested structs. However, that won't work for SIMT, because only some of the threads might be alive at this point in the program. Therefore, that assignment has to only copy the items of the vectors for the threads that are alive; it can't copy the whole vectors because that can clobber other entries in the destination vector which are supposed to persist.
So, all the stores to someVariableOfFooType need to be predicated using the predication registers - we can't naively use a memcpy(). This means that every assignment needs to actually perform n memory operations, where n is the number of leaf types in the struct being assigned - because those memory operations can be predicated correctly using the predication registers. We have to copy structs leaf-by-leaf. This means that the number of instructions to copy a type is proportional to the complexity of the type. Also, both the left side and the right side may be l-values, which means each leaf-copy could actually be a gather/scatter pair of instructions. So, depending on the complexity of the type and the context of the assignment, that single "=" operation might actually generate a huge amount of code.
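A minimal sketch of the leaf-by-leaf copy, in Python. Each variable is modeled as a dictionary from leaf name to a list of per-thread values; that representation and the name copy_struct are just for illustration.

```python
def copy_struct(dst, src, mask):
    """Copy src into dst one leaf at a time, touching only the live lanes."""
    for leaf_name in dst:
        dst_lanes, src_lanes = dst[leaf_name], src[leaf_name]
        for i in range(len(dst_lanes)):
            if (mask >> i) & 1:          # predicated store for this leaf
                dst_lanes[i] = src_lanes[i]


destination = {"x": [1, 1, 1, 1], "y": [2, 2, 2, 2]}
returned = {"x": [10, 11, 12, 13], "y": [20, 21, 22, 23]}

# Only threads 0 and 1 are alive at the assignment.
copy_struct(destination, returned, 0b0011)
```

A whole-vector memcpy() here would have clobbered the values that threads 2 and 3 still expect to find in destination.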
Pointers (Part 2)
There's one other decision that needs to be made about pointers: Consider:
variable x: uint64;
... &x ...
As I described above, the storage for the variable x is a vector (each thread owns one value in the vector). &x produces a vector of pointers, sure. The question is: should all the pointer values point to the beginning of the x vector? Or should each pointer value point to its own slot inside the x vector? If they point to the beginning, that makes the & operator itself really straightforward: it's just a broadcast instruction. But it also means that the scatter/gather operations get more complicated: they have to offset each pointer by a different amount in order to scatter/gather to the correct place. On the other hand, if each pointer points to its own slot inside x, that means the scatter/gather operations are already set up correctly, but the & operation itself gets more complicated.
Both options will work, but I ended up making all the pointers point to the beginning of x. The reason for that is for programs like:
struct Foo {
    x: uint32;
    y: uint64;
}
variable x: Foo;
... &x ...
If I picked the other option, and had the pointers point to their own slot inside x, it isn't clear which member of Foo they should be pointing inside of. I could have, like, found the first leaf, and made the pointers point into that, but what if the struct is empty... It's not very elegant.
Also, if I'm assigning to x or something where I need to populate every field, because every copy operation has to copy leaf-by-leaf, I'm going to have to be modifying the pointers to point to each field. If one of the fields is a uint32 and the next one is a uint64, I can't simply just add a constant amount to each pointer to get it to point to its slot in the next leaf. So, if I'm going to be mucking about with individual pointer values for each leaf in a copy operation, I might as well have the original pointer point to the overall x vector rather than individual fields, because pointing to individual fields doesn't actually make anything simpler.
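The arithmetic behind that last point can be sketched in Python. I'm assuming 8 lanes and that the uint32 vector is immediately followed by the uint64 vector with no padding; lane_addresses is a hypothetical helper, and 0x1000 is an arbitrary base address.

```python
NUM_LANES = 8


def lane_addresses(base_pointers, leaf_offset, leaf_size):
    """Per-lane address of a leaf: base + offset of the leaf's vector + lane slot."""
    return [base_pointers[i] + leaf_offset + i * leaf_size
            for i in range(NUM_LANES)]


# &x is a broadcast: every lane points at the beginning of the storage.
base = [0x1000] * NUM_LANES

# Foo { x: uint32; y: uint64; } laid out as a uint32 vector then a uint64 vector.
x_addresses = lane_addresses(base, leaf_offset=0, leaf_size=4)
y_addresses = lane_addresses(base, leaf_offset=4 * NUM_LANES, leaf_size=8)

# Distance from each lane's x slot to its y slot.
deltas = [y - x for x, y in zip(x_addresses, y_addresses)]
```

The deltas come out different for every lane, which is why a pointer to "my slot in x" can't be turned into a pointer to "my slot in y" by adding one constant.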
Function Calls
This language supports function pointers, which are callable. This means that you can write a program like this (taken from the test suite):
function helper1(): uint64 ...
function helper2(): uint64 ...
function main(index: uint64): uint64 {
    variable x: FunctionPointer<uint64>;
    if (index < 3) {
        x = helper1;
    } else {
        x = helper2;
    }
    return x();
}
Here, that call to x() allows different threads to point to different functions. This is a problem for us, because all the "threads" that are running share the same instruction pointer. We can't actually have some threads call one function and other threads call another function. So, what we have to do instead is to set the predication bitmask to only the "threads" which call one function, then call that function, then set the predication bitmask to the remaining threads, then call the other function. Both functions get called, but the "threads" that are alive during each call are only the ones that are supposed to actually be running the function.
This is tricky to get right, though, because anything could be in that function pointer vector. Maybe all the threads ended up with the same pointers! Or maybe each thread ended up with a different pointer! You *could* do the naive thing and do something like:
for i in 0 ..< numThreads:
    predicationMask = originalPredicationMask & (1 << i)
    call function[i]
But this has really atrocious performance characteristics. This means that every call actually calls numThreads functions, one-by-one. But each one of those functions can have more function calls! The execution time will be proportional to numThreads ^ callDepth. Given that function calls are super common, this exponential runtime isn't acceptable.
Instead, what you have to do is gather up and deduplicate function pointers. You need to do something like this instead:
func generateMask(functionPointers, target):
    mask = 0;
    for i in 0 ..< numThreads:
        if functionPointers[i] == target:
            mask |= 1 << i;
    return mask;
for pointer in unique(functionPointers):
    predicationMask = originalPredicationMask & generateMask(functionPointers, pointer)
    call pointer
I couldn't find an instruction in the Intel instruction set that did this. This is also a complicated enough algorithm that I didn't want to write this in assembly and have the compiler emit the instructions for it. So, instead, I wrote it in C++, and had the compiler emit code to call this function at runtime. Therefore, this routine can be considered a sort of "runtime library": a function that automatically gets called when the code the author writes does a particular thing (in this case, "does a particular thing" means "calls a function").
Doing it this way means that you don't get exponential runtime. Indeed, if your threads all have the same function pointer value, you get constant runtime. And if the threads diverge, the slowdown will be at most proportional to the number of threads. You'll never run a function where the predication bitmask is 0, which means there is a bound on how slow the worst case can be - it will never get worse than having each thread individually diverge from all the other threads.
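Here is a runnable Python version of the deduplicating dispatch, as a sketch of the idea rather than the actual C++ runtime routine. Python functions stand in for function pointers, each callee receives its predication mask as an argument, and I'm assuming that lanes which are already dead are skipped when collecting unique pointers (that is what keeps a function from ever being called with an empty mask).

```python
def generate_mask(function_pointers, target):
    """Bitmask of the lanes whose function pointer is exactly target."""
    mask = 0
    for lane, pointer in enumerate(function_pointers):
        if pointer is target:
            mask |= 1 << lane
    return mask


def call_each_unique(function_pointers, live_mask):
    """Call every distinct pointer held by a live lane once, with the right mask."""
    already_called = []
    for lane, pointer in enumerate(function_pointers):
        lane_is_live = (live_mask >> lane) & 1
        if not lane_is_live or any(pointer is p for p in already_called):
            continue
        already_called.append(pointer)
        pointer(live_mask & generate_mask(function_pointers, pointer))


call_log = []


def helper1(mask):
    call_log.append(("helper1", mask))


def helper2(mask):
    call_log.append(("helper2", mask))


# Threads 0-2 picked helper1, threads 3-7 picked helper2.
pointers = [helper1] * 3 + [helper2] * 5
call_each_unique(pointers, 0xFF)
```

With all eight threads alive, this makes exactly two calls: helper1 with mask 0b00000111 and helper2 with mask 0b11111000.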
Control Flow
As described above, control flow (meaning: if statements, for loops, breaks, continues, and returns) is implemented by changing the value of the predication bitmask register. The x86_64 instruction set has instructions that do this.
There are 2 ways to handle the predication registers. One way is to observe the fact that there are 8 predication registers, and to limit the language to only allow 8 (7? 6? 3?) levels of nested control flow. If you pick this approach, the code that you emit inside each if statement and for loop would use a different predication register. (Sibling if statements can use the same predication register, but nested ones have to use different predication registers.)
I elected to not add this restriction, but instead to save and restore the values of the predication register to the stack. This is slower, but it means that control flow can be nested without limit. So, the instructions I emit are all predicated on the k1 register - I never use k2 - k7 (except - I use k2 to save/restore the value of k1 during scatter/gather operations because those clobber the value of whichever register you pass into them).
For an "if" statement, you actually need to save 2 predication masks:
- One that saves the predication mask that was incoming to the beginning of the "if" statement. You need to save this so that, after the "if" statement is totally completed, you can restore it back to what it was originally
- If there's an "else" block, you also need to save the bitmask of the threads that should run the "else" block. You might think that you can compute this value at runtime instead of saving/loading it (it would be the inverse of the threads that ran the "then" block, and-ed with the set of incoming threads) but you actually can't do that because break and continue statements might actually need to modify this value. Consider if there's a break statement as a direct child of the "then" block - at the end of the "then" block, there will be no threads executing (because they all executed the "break" statement). If you then use the set of currently executing threads to try to determine which should execute the "else" block, you'll erroneously determine that all threads (even the ones which ran the "then" block!) should run the "else" block. Instead, you need to compute up-front the set of threads that should be running the "else" block, save it, and re-load it when starting to execute the "else" block.
For a "for" loop, you also need to save 2 predication masks:
- Again, you need to store the incoming predication mask, to restore it after the loop has totally completed
- You also need to save and restore the set of threads which should execute the loop increment operation at the end of the loop. The purpose of saving and restoring this is so that break statements can modify it. Any thread that executes a break statement needs to remove itself from the set of threads which executes the loop increment. Any thread that executes a continue statement needs to remove itself from executing *until* the loop increment. Again, this is a place where you can't recompute the value at runtime because you don't know which threads will execute break or continue statements.
If you set up "if" statements and "for" loops as above, then break and continue statements actually end up really quite simple. First, you can verify statically that no statement directly follows them - they should be the last statement in their block.
Then, what a break statement does is:
- Find the deepest loop it's inside of, and find all the "if" statements between that loop and the break statement
- For each of the "if" statements:
  - Emit code to remove all the currently running threads from both of the saved bitmasks associated with that "if" statement. Any thread that executes a break statement should not run an "else" block and should not come back to life after the "if" statement.
- Emit code to remove all the currently running threads from just the second bitmask associated with the loop. (This is the one that gets restored just before the loop increment operation). Any thread that executes a break statement should not execute the loop increment.
A "continue" statement does the same thing except for the last step (those threads *should* execute the loop increment). And a "return" statement removes all the currently running threads from all bitmasks from every "if" statement and "for" loop it's inside of.
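To check my understanding of the bookkeeping, here is a Python emulation of a loop containing an "if" with a break, run across several lanes at once. The loop being emulated is "for (i = 0; i < limit; i = i + 1) { if (i == stop) { break; } total = total + i; }", where limit and stop differ per thread. All the names are invented for this sketch, and plain Python integers stand in for the saved bitmasks.

```python
def lanes_where(predicate, num_lanes):
    """Build a bitmask with bit i set when predicate(i) is true."""
    mask = 0
    for lane in range(num_lanes):
        if predicate(lane):
            mask |= 1 << lane
    return mask


def simt_loop(limits, stops):
    n = len(limits)
    i, total = [0] * n, [0] * n
    loop_incoming = (1 << n) - 1    # loop mask 1: restored after the loop
    live = loop_incoming
    while True:
        # Loop condition: lanes that fail it stop executing.
        live &= lanes_where(lambda lane: i[lane] < limits[lane], n)
        if live == 0:
            break                   # nobody left; a real compiler would jump out
        increment_mask = live       # loop mask 2: lanes that run the increment
        if_incoming = live          # "if" mask 1: restored after the "if"
        live &= lanes_where(lambda lane: i[lane] == stops[lane], n)
        # The break statement: running lanes leave every saved mask.
        if_incoming &= ~live
        increment_mask &= ~live
        live = if_incoming          # end of the "if": restore
        for lane in range(n):       # total = total + i (predicated)
            if (live >> lane) & 1:
                total[lane] += i[lane]
        live = increment_mask       # restore just before the increment
        for lane in range(n):       # i = i + 1 (predicated)
            if (live >> lane) & 1:
                i[lane] += 1
    return total


def scalar_loop(limit, stop):
    """Reference: what one thread would compute on its own."""
    total = 0
    for i in range(limit):
        if i == stop:
            break
        total += i
    return total


limits, stops = [0, 3, 5, 5], [9, 9, 2, 9]
vector_result = simt_loop(limits, stops)
```

Notice there is no jump for the break itself: the lanes that hit it simply never reappear in any mask, so they sit out the remaining iterations while the other lanes keep going.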
This is kind of interesting - it means an early return doesn't actually stop the function or perform a jmp or ret. The function still continues executing, albeit with a modified predication bitmask, because there might still be some threads "alive." It also means that "if" statements don't actually need to have any jumps in them - in the general case, both the "then" block and the "else" block will be executed, so instead of jumps you can just modify the predication bitmasks - and emit straight-line code. (Of course, you'll want the "then" block and the "else" block to both jump to the end if they find that they start executing with an empty predication bitmask, but this isn't technically necessary - it's just an optimization.)
Shared Variables
When you're using the SIMT approach, one thing that becomes useful is the ability to interact with external memory. GPU threads don't really perform I/O as such, but instead just communicate with the outside world via reading/writing global memory. This is a bit of a problem for SIMT-generated code, because it will assume that the type of everything is vector type - one for each thread. But, when interacting with external memory, all "threads" see the same values - a shared int is just an int, not a vector of ints.
That means we now have a 3rd kind of value classification. Previously, we had l-values and r-values, but l-values can be further split into vector-l-values and scalar-l-values. A pointer type now needs to know statically whether it points to a vector-l-value or a scalar-l-value. (This information needs to be preserved as we pass it from l-value pointers through the & and * operators.) In the language, this looks like "pointer<uint64 | shared>".
It turns out that, beyond the classical type-checking analysis, it's actually pretty straightforward to deal with scalar-l-values. They are actually strictly simpler than vector-l-values.
In the language, you can declare something like:
variable<shared> x: uint64;
x = 4;
which means that it is shared among all the threads. If you then refer to x, that reference expression becomes a scalar-l-value, and produces a vector of pointers, all of which point to x's (shared) storage. The "=" in the "x = 4;" statement now has to be made aware that:
- If the left side is a vector-l-value, then the scatter operation needs to offset each pointer in the vector to point to the specific place inside the destination vectors that the memory operations should write to
- But, if the left side is a scalar-l-value, then no such offset needs to occur. The pointers already point to the one single shared memory location. Everybody points to the right place already.
(And, of course, same thing for the right side of the assignment, which can be either a vector-l-value, a scalar-l-value, or an r-value.)
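In emulation, the difference between the two kinds of l-value is a single branch. This Python sketch uses a list as memory and word indices as pointers; the store function and its shared flag are hypothetical. When several live lanes store to the same shared location, the highest live lane wins in this sketch simply because the loop runs in lane order.

```python
def store(memory, pointers, values, mask, shared):
    """Predicated store through a vector of pointers.

    Vector-l-values: each pointer is the start of a vector, so lane i is
    offset to its own slot. Scalar-l-values: the pointers are used as-is.
    """
    for lane in range(len(pointers)):
        if (mask >> lane) & 1:
            address = pointers[lane] if shared else pointers[lane] + lane
            memory[address] = values[lane]


memory = [0] * 16
local_base = 0      # a per-thread variable occupying words 0..3
shared_slot = 8     # a shared variable occupying word 8

# Per-thread variable: every lane writes its own slot.
store(memory, [local_base] * 4, [4, 4, 4, 4], 0b1111, shared=False)

# Shared variable: the live lanes all write the same word.
store(memory, [shared_slot] * 4, [7, 7, 7, 7], 0b0110, shared=True)
```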
Comparisons and Booleans
AVX-512 of course has vector compare instructions. The result of these vector comparisons *isn't* another vector. Instead, you specify one of the bitmask registers to receive the result of the comparison. This is useful if the comparison is the condition of an "if" statement, but it's also reasonable for a language to have a boolean type. If the boolean type is represented as a normal vector holding 0s and 1s, there's an elegant way to convert between the comparison and the boolean.
The comparison instructions look like:
vpcmpleq %zmm1,%zmm0,%k2{%k1}
If you were to speak this aloud, what you'd say is "do a vector packed compare for less-than-or-equal-to on the quadwords in zmm0 and zmm1, put the result in k2, and predicate the whole operation on the value of k1." Importantly, the operation itself is predicated, and the result can be put into a different predication register. This means that, after you execute this thing, you know which threads executed the instruction (because k1 is still there) but you also know the result of the comparison (because it's in k2).
So, what you can do is: use k1 to broadcast a constant 0 into a vector register, and then use k2 to broadcast a constant 1 into the same vector register. This will leave a 1 in all the spots where the test succeeded, and a 0 in all the remaining spots. Pretty cool!
If you want to go the other way, to convert from a boolean to a mask, you can just compare the boolean vector to a broadcasted 0, and compare for "not equal." Pretty straightforward.
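Both conversions can be emulated in a few lines of Python. The helper names are mine; compare_le plays the role of the predicated compare that writes k2, and masked_broadcast plays the role of a predicated broadcast.

```python
def compare_le(a, b, k1):
    """Predicated compare: bit i is set when lane i is live and a[i] <= b[i]."""
    k2 = 0
    for lane in range(len(a)):
        if (k1 >> lane) & 1 and a[lane] <= b[lane]:
            k2 |= 1 << lane
    return k2


def masked_broadcast(dst, value, mask):
    """Write value into the lanes selected by mask; leave other lanes alone."""
    return [value if (mask >> lane) & 1 else dst[lane] for lane in range(len(dst))]


def mask_to_boolean(dst, k1, k2):
    """Broadcast 0 under k1, then 1 under k2."""
    return masked_broadcast(masked_broadcast(dst, 0, k1), 1, k2)


def boolean_to_mask(booleans, k1):
    """Compare against a broadcast 0 for "not equal", predicated on k1."""
    mask = 0
    for lane in range(len(booleans)):
        if (k1 >> lane) & 1 and booleans[lane] != 0:
            mask |= 1 << lane
    return mask


k1 = 0b0111                                   # lane 3 is dead
k2 = compare_le([1, 5, 3, 9], [4, 4, 4, 4], k1)
booleans = mask_to_boolean([7, 7, 7, 7], k1, k2)
round_trip = boolean_to_mask(booleans, k1)
```

The dead lane keeps whatever junk was in the destination (7 here), and the round trip back to a mask ignores it because that compare is predicated on k1 too.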
Miscellanea
I'm using my own calling convention (and ABI) to pass values into and out of functions. It's for simplicity - the x64 calling convention is kind of complicated if you're using vector registers for everything. One of the most useful decisions I made was to formalize this calling convention by encoding it in a C++ class in the compiler. Rather than having various different parts of the compiler just assume they knew where parameters were stored, it was super useful to create a single source of truth about the layout of the stack at call frame boundaries. I ended up changing the layout a few different times, and having this single point of truth meant that such changes only required updating a single class, rather than a global change all over the compiler.
Inventing my own ABI also means that there will be a boundary, where the harness will have to call the generated code. At this boundary, there has to be a trampoline, where the contents of the stack get rejiggered to set it up for the generated code to look in the right place for stuff. And, this trampoline can't be implemented in C++, because it has to do things like align the stack pointer register, which you can't do in C++. The aligned AVX-512 load and store instructions require vectors to be at 64-byte alignment, but Windows only requires 16-byte stack alignments. So, in my own ABI I've said "stack frames are all aligned to 64-byte boundaries" which means the trampoline has to enforce this before the entry point can be run. So the trampoline has to be written in assembly.
The scatter/gather operations (which are required for l-values to work) only operate on 32-bit and 64-bit values inside the AVX-512 registers. This means that the only types in the language can be 32-bit and 64-bit types. An AVX-512 vector, which is 512 bits = 64 bytes wide, can hold 8 64-bit values, or 16 32-bit values. However, the entire SIMT programming model requires you to pick up front how many "threads" will be executing at once. If some calculations in your program can calculate 8 values at a time, and some other calculations can calculate 16 values at a time, it doesn't matter - you have to pessimize and only use 8-at-a-time. So, if the language contains 64-bit types, then the max number of "threads" you can run at once is 8. If the language only contains 32-bit types (and you get rid of 64-bit type support, including 64-bit pointers), then you can run 16 "threads" at once. I chose to include 64-bit types and do 8 "threads" at a time, because I didn't want to limit myself to the first 4GB of memory (the natural stack and heap are already farther than 4GB apart from each other in virtual address space, so I'd have to, like, mess with Windows's VM subsystem to allocate my own stack/heap and put them close to each other, and yuck I'll just use 64-bit pointers thankyouverymuch).
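The thread-count arithmetic is simple enough to write down. In this Python sketch, lanes_for is a made-up helper that takes the bit widths of every type the language supports:

```python
VECTOR_BITS = 512  # width of one AVX-512 zmm register


def lanes_for(type_widths_in_bits):
    """The widest supported type decides how many "threads" fit in one vector."""
    return VECTOR_BITS // max(type_widths_in_bits)


with_64_bit_types = lanes_for([32, 64])   # 8 "threads" at a time
only_32_bit_types = lanes_for([32])       # 16 "threads" at a time
```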
Conclusion
And that's kind of it. I learned a lot along the way - there seem to be good reasons why, in many shading languages (which use this SIMT model),
- Support for 8-bit and 16-bit types is rare - the scatter/gather operations might not support them.
- Support for 64-bit types is also rare - the smaller your types, the more parallelism you get, for a particular vector bit-width.
- Memory loads and stores turn into scatter/gather operations instead.
  - A sophisticated compiler could optimize this, and turn some of them into vector loads/stores instead.
  - This might be why explicit support for pointers is relatively rare in shading languages - no pointers means you can _always_ use vector load/store operations instead of scatter/gather operations (I think).
- You can't treat memory as a big byte array and memcpy() stuff around; instead you need to treat it logically and operate on well-typed fields, so the predication registers can do the right thing.
- Shading languages usually don't have support for function pointers, because calling them ends up becoming a loop (with a complicated pointer coalescing phase, no less) in the presence of non-uniformity. Instead, it's easy for the language to just say "You know what? All calls have to be direct. Them's the rules."
- Pretty much every shading language has a concept of multiple address spaces. The need for them naturally arises when you have local variables which are stored in vectors, but you also need to interact with global memory, which every thread "sees" identically. Address spaces and SIMT are thoroughly intertwined.
- I thought it was quite cool how AVX-512 complemented the existing (scalar) instruction set. E.g. all the math operations in the language use vector operations, but you still use the normal call/ret instructions. You use the same rsp/rbp registers to interact with the stack. The vector instructions can still use the SIB byte. The broadcast instruction broadcasts from a scalar register to a vector register. Given that AVX-512 came out of the Larrabee project, it strikes me as a very Intel-y way to build a GPU instruction set.