In this post I will write about green threads and provide a sample implementation in C++ - or, better say, in C++ mixed with assembly language. The code is highly compiler-dependant (VC++17) and fragile, thus not production-ready. It makes assumptions about stack layout and data in registers and even something as simple as turning on optimizations will most likely crash it. But it was a fun programming exercise for a concept often implemented in asynchronous libraries or found in some programming language constructs. Interesting links: Fibers, Coroutines, Set Context, Actors, Cooperative multitasking

The problem

You have N CPUs and M threads, where M » N - the case for actors, for instance. You know that these M threads will most likely interact with each other (either passing messages or synchronizing on some primitive). You might also have I/O and you want to keep the program flow simple and not pollute it with lots of callbacks - synchronous I/O code is clearly much simpler to read than async I/O, but also more heavy on the OS. These problems are suitable for considering a cooperative scheduling approach.

vs_run

The API and the test scenario

For this demo I will only implement wait-for (waiting for a thread to finish) and yield (give control to the next thread). Any other operations would only make the code more complex to read but not add much clarity into the concepts. Here is my test bed:

void fn(thread_pool::thread_ctx *ctx, int p1, int p2) {

	if (p1 >= 4) {
		cout << "Ended NO yields: " << ctx->name() << endl;
		return;
	}
	
	char thread_name[80];
	sprintf_s(thread_name, "Thread %d - %d", p1 + 1, p2 + 2);
	cout << "Starts: " << thread_name << endl;

	auto child1 = ctx->call_fn(100000, thread_name, fn, p1 + 1, p2 + 2);

	sprintf_s(thread_name, "Secondary thread %d - %d", p1 + 1, p2 + 2);
	cout << "Starts: " << thread_name << endl;

	auto child2 = ctx->call_fn(100000, thread_name, fn, p1 + 1, p2 + 2);

	std::vector<decltype(child1)> children;

	for (int i = 0; i < 7; i++) {

		cout << p1 << p2 << endl;

		auto p = ctx->call_fn(10000, "Child of child", fn, p1 + 1, p2 + 2);

		children.push_back(p);
		ctx->yield();
	}

	ctx->wait_for(child1);
	ctx->wait_for(child2);

	for (auto &c : children)
		ctx->wait_for(c);

	cout << "Ended WITH yields: " << ctx->name() << endl;
}

int main()
{
	int p1 = 0;
	int p2 = 0;

	thread_pool pool;
	pool.call_fn(100000, "THREAD_0" ,fn, p1, p2);	
	return 0;
}

In main I simply call on the main thread, which I named “THREAD_0”, a procedure, fn, defined above. This is (and should be) a synchronous call. fn should wait for all spawned threads to be finished. Note: when I say “thread”, I mean a “green thread”. All my threads share the same OS thread. In this sample there is no OS multithreading involved. From the OS perspective the application is single threaded.

The fn function spawns recursively many other child-threads which are waited for at the end of the function. Context switching happens when control reaches the ctx->yield() call. ctx is our current thread (an instance of the thread_context class). If control does not reach a ctx->yield(), the function is simply executed on the stack of the caller, synchronously, like a normal function. When control reaches ctx->yield() the function will be put on hold and the rest of the threads will be executed. On the spawning thread, control is given back to the caller, asynchronously, not waiting for the new thread to finish. The result is a thread_ctx object which can be queried for work completion or waited for. When a thread is spawned, parameters can be sent to it like to any other function, on the stack - the implementation is based on a variadic template. In our case I send two ints.

API description:

template<class fn, typename... T> void thread_pool::call_fn(unsigned int stack_size, const char* name, fn* f, T... params) - creates the parent thread. It is the way to initialize the threading library because for the parent thread return information should be stored in the main process stack. The function should wait for all its children to finish before exiting.

template<class fn, typename... T> shared_ptr<thread_ctx> thread_pool::thread_ctx::call_fn(unsigned int stack_size, const char* name ,fn* f, T... params) - spawns a new child thread.

void thread_pool::thread_ctx::wait_for(shared_ptr<thread_ctx>& other) - waits for a child.

void thread_pool::thread_ctx::yield() - gives control to the next thread waiting for execution in a round-robin fashion.

Implementation

Here is the full listing:

#define _DEBUG

#pragma optimize ("", off)

class thread_pool {

private:
	unsigned int current_thread = 0;	

public:

	struct thread_ctx {

	private:
		char* thread_name = nullptr;

	private:
		unsigned char* stack = nullptr;
		unsigned char* stack_ptr = nullptr;
		unsigned char* continuation_location = nullptr;

		thread_pool* th_p = nullptr;
		bool	b_finished = false;
		
		// for when we have an out of order return from the function.
		bool	b_yielded = false; 

#ifdef _DEBUG
		int stack_size = 0;
#endif

	public: 
		friend class thread_pool;
		thread_ctx(thread_pool* tp = nullptr, int _stack_size = -1) : th_p(tp){

#ifdef _DEBUG
			stack_size = _stack_size;
#endif
			stack = new unsigned char[_stack_size];
			stack_ptr = stack + _stack_size;
		}

		~thread_ctx() {
			delete[] stack;
			// cout << "Deleted: " << thread_name << endl;
			if(thread_name)
				free(thread_name);
		}

		thread_pool* get_thread_pool() { return th_p; }

		const char* name() const { return thread_name; }

		bool finished() const { return b_finished; }

		void wait_for(shared_ptr<thread_ctx>& other) {
			while (!other->b_finished)
				yield();
		}

		void yield() {

			int stk_tst = 0;
			b_yielded = true;

#ifdef _DEBUG
			// test we are on our stack
			assert( int(&stk_tst) >= int(stack) && int(&stk_tst) < stack_size + int(stack));
#endif
			
			unsigned char** loc_ptr		= &continuation_location;
			auto next			= th_p->next();

			while (next->b_finished || 
				(next.get() != this && next->continuation_location == nullptr))
				next = th_p->next();

			auto next_ptr = next.get();

			// cout << "Yield: From " << this->thread_name << " To " 
			// << next_ptr->thread_name << endl;

			// save the jump location for this thread
			__asm {				
				mov eax, offset continuation_code_ptr;
				mov ebx, loc_ptr;
				mov [ebx], eax;
			}

			// save the current stack_ptr to restore it
			__asm{
				mov ecx, this;
				push stk_tst;
				push ebp;
				mov [ecx + stack_ptr], esp;
			}			
			
			// do the jump to to next location,
			__asm {
				mov ecx, next_ptr;
				mov ebx, [ecx + continuation_location]; // next is in ecx
				mov esp, [ecx + stack_ptr];
				jmp ebx;			
			}

			assert(0); // should never get here
			
			// jmp location:
			__asm
			{
			continuation_code_ptr:
				pop ebp;
				pop stk_tst; // should be 0;
			}

			assert(stk_tst == 0);			
		}

		template<class fn, typename... T> 
		shared_ptr<thread_ctx> 
		call_fn(unsigned int stack_size, const char* name ,fn* f, T... params) {
			auto ctx = make_shared<thread_ctx>(th_p, stack_size);

			if (name) 
				ctx->thread_name = _strdup(name);

			th_p->threads.push_back(ctx);
			ctx->assign_fn(this, f, params...);
			return ctx;
		}

	private:

		template<class fn, typename... T> 
		void assign_fn(thread_ctx* parent_thread, fn* f, T... params) {
			int stk_tst = 0;

			unsigned char* stk = stack_ptr;	

			if (parent_thread == nullptr) { // we are on root of threads
				// setup start of new thread
				__asm {
					mov eax, esp; // save old stack on new stack
					mov esp, stk;
					push eax;
				}

				f(this, params...);

				__asm {
					pop esp;
				}

				b_finished = true;
				th_p->remove(this);
			}
			else {
				unsigned char** loc_ptr		= &parent_thread->continuation_location;
				unsigned char*	stck_ptr	= nullptr;
				
				// save the jump location 
				__asm {
					mov eax, offset continuation_code_ptr;
					mov ebx, loc_ptr;
					mov[ebx], eax;
				}

				// save the current stack_ptr to restore it (we are on parent stack)
				__asm {
					push stk_tst;
					push ebp;
					mov[stck_ptr], esp;
				}

				parent_thread->stack_ptr = stck_ptr;

				// switch to the new thread stack
				__asm {
					mov esp, stk;
					push stck_ptr; // these still work because we have not changed ebp
					mov eax, this;
					push eax;
				}

				f(this, params...);

				__asm {
					pop ecx; // this is in ecx
					mov eax, [ecx + b_yielded];
					and al, 1
					jnz function_already_returned_once;
				}

				// switch back to parent stack
				// restore this
				__asm {
					pop eax; // stk_ptr
					mov esp, eax; // stk_ptr
					pop ebp;					
					pop stk_tst; // here we already have the values
				}

				assert(stk_tst == 0);

				b_finished = true;
				th_p->remove(this);

				return;

				__asm {
				function_already_returned_once:

					// this is in ecx
					mov eax, [ecx + b_finished];
					or al, 1;
					mov[ecx + b_finished], eax; // set the b_finished flag to true

					mov ebx, ecx; // save this temporary
					mov ecx, [ebx + th_p];

					push ebx; // once for saving, once for remove call
					push ebx;
					call remove;

					pop ecx;
					call yield

					// TODO: set the b_finished flag and then remove the thread
					// do yield
				}

				assert(0); // should never get here.

				__asm
				{
				continuation_code_ptr: // only for jmp code from another place
					pop ebp;
					pop stk_tst; // should be 0;
				}

				assert(stk_tst == 0);
				// cout << "Async return for " << this->thread_name << endl;
			}

		}			
	};

private:
	std::vector<shared_ptr<thread_ctx>> threads;

public:

	shared_ptr<thread_ctx> next() {
		for( unsigned th = 0; th < threads.size() ; th ++){
			auto ret = threads[(++current_thread) % threads.size()];
			if (ret != nullptr)
				return ret;
		} 

		return nullptr;
	}

	void remove(const thread_ctx* ctx) {
		for(unsigned int i = 0 ; i < threads.size(); i++)
			if (threads[i].get() == ctx) {
				threads[i] = threads[threads.size() - 1];
				threads.pop_back();
				return;
			}
		assert(0); // not found
	}

	~thread_pool() {
		for (unsigned int i = 0; i < threads.size(); i++) {
			threads[i] = nullptr;			
		}
	}

	template<class fn, typename... T> 
	void call_fn(unsigned int stack_size, const char* name, fn* f, T... params) {
		auto ctx = make_shared<thread_ctx>(this, stack_size);

		if(name) ctx->thread_name = _strdup(name);

		threads.push_back(ctx);
		ctx->assign_fn(nullptr, f, params...);
	}
};

Explanations:

struct thread_ctx {

	private:
		char* thread_name = nullptr;

	private:
		unsigned char* stack = nullptr;
		unsigned char* stack_ptr = nullptr;
		unsigned char* continuation_location = nullptr;

		thread_pool* th_p = nullptr;
		bool	b_finished = false;
		// for when we have an out of order return from the function.
		bool	b_yielded = false; 
		
		...
}
  • thread_ctx - our thread.
  • thread_name - a private variable used for debug - helps identifying our thread.
  • stack - the new-allocated stack for our thread.
  • stack_ptr - stack pointer (esp). When a context switch is performed, esp is saved / restored to / from this variable.
  • continuation_location - pointer to the code location from which the thread was interrupted. It is saved and jumped to on context switch.
  • thread_pool - the parent thread, pool,
  • b_finished set if the thread has finished its work.
  • b_yielded an internal status set to true if the thread was ever interrupted. This is critical because, if the thread has been interrupted, the invocation will return asynchronously to the parent through an out-of-order return (void assign_fn(thread_ctx* parent_thread, fn* f, T... params)). The function will then continue, without another return to the caller. When it finishes, it just forwards control to the next waiting thread (a final call to yield()) and sets the b_finished flag.

Some compiler specific assumptions:

  • ebp is the root of the function stack./Oy will crash our code. This restriction could be removed through some extra push / pops.
  • ecx stores the this pointer (thiscall convention assumed). To access any members from this we do mov ..., [ecx + var_name]
  • inline asm assumes VC++ on 32 bits.

Any optimizations that somehow assume the threads continue without interruptions will crash our code.

Comments

  • I wrote this code for my own fun, with no other intention except getting it to work.
  • Assembly language is extremely powerful and suited for such hacks. Optimizations are usually better left off to the compiler, but doing magic and jumping between functions? Well, there are C APIs already battletested for this, but it is fun to play around and provide a possible implementation for how these APIs work behind the scenes.
  • Debugging jumps between functions is relatively hard as there is little information availabe. I coded incrementally, getting each line of code working, one by one.