#include <assert.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include "thread.h"
#include "synch.h"

/* Per-seat occupancy flags: seat_taken[i] is true while some thread holds seat i. */
bool *seat_taken;
/* One lock per seat; seat_taken_lock[i] guards reads/writes of seat_taken[i]. */
struct lock *seat_taken_lock;
/* How many times each seat has been occupied; incremented by the seat's current holder. */
int *seat_used_count;
/* Number of entries in each of the arrays above (set by init_seats). */
int num_seats;

/* Make one pass over all seats, claiming the first free one.
 *
 * Each seat is inspected under its own lock. On success the claimed index
 * is stored in *taken and true is returned; if every seat is occupied,
 * *taken is left untouched and false is returned.
 */
static bool grab_seat(int *taken) {
    bool found = false;

    for (int seat = 0; seat < num_seats && !found; seat++) {
        lock_acquire(&seat_taken_lock[seat]);
        if (!seat_taken[seat]) {
            seat_taken[seat] = true;
            *taken = seat;
            found = true;
        }
        lock_release(&seat_taken_lock[seat]);
    }

    return found;
}

/* Obtain a seat, retrying until one becomes free, and return its index.
 *
 * This is a busy-wait: it rescans via grab_seat() with no blocking or yield
 * in between. NOTE(review): progress relies on the thread library letting
 * the current seat holder run (e.g. scheduling at lock operations) — confirm.
 */
int seat_acquire(void) {
    int seat = -1;

    for (;;) {
        if (grab_seat(&seat))
            return seat;
    }
}

/* Mark seat `id` as free again, under that seat's lock.
 * The caller is expected to pass an index previously returned by
 * seat_acquire(); no range check is performed here.
 */
void seat_release(int id) {
    struct lock *guard = seat_taken_lock + id;

    lock_acquire(guard);
    seat_taken[id] = false;
    lock_release(guard);
}

/* Allocate and initialise the seat table for `count` seats.
 *
 * Every seat starts free (seat_taken[i] == false) with a zero use count,
 * and each seat gets its own initialised lock. Aborts the process if
 * `count` is negative or if any allocation fails, because the rest of the
 * program indexes these arrays without further checks.
 */
void init_seats(int count) NO_STEP {
    if (count < 0) {
        /* A negative count would wrap in the size computation and leave
         * num_seats negative, making seat_acquire() spin forever. */
        fprintf(stderr, "init_seats: negative seat count %d\n", count);
        abort();
    }

    size_t n = (size_t)count;

    /* calloc zero-fills (giving false flags and zero counters) and checks
     * n * size for overflow. */
    seat_taken = calloc(n, sizeof *seat_taken);
    seat_taken_lock = calloc(n, sizeof *seat_taken_lock);
    seat_used_count = calloc(n, sizeof *seat_used_count);

    /* calloc(0, ...) may legitimately return NULL, so only treat NULL as a
     * failure when memory was actually requested. */
    if (n > 0 && (seat_taken == NULL || seat_taken_lock == NULL ||
                  seat_used_count == NULL)) {
        fprintf(stderr, "init_seats: out of memory for %d seats\n", count);
        abort();
    }

    num_seats = count;
    for (int i = 0; i < count; i++)
        lock_init(&seat_taken_lock[i]);
}

/* Raised once at the end of every thread_fn() run; main() downs it once per runner to wait for all of them. */
struct semaphore threads_done;

/* Work performed by every participant (spawned threads and main itself):
 * occupy one seat, record the visit, give the seat back, then signal
 * completion on threads_done.
 */
void thread_fn() {
    int my_seat = seat_acquire();

    /* No lock here: this relies on the seat being held exclusively between
     * seat_acquire() and seat_release(), so only one thread bumps this
     * counter at a time. */
    seat_used_count[my_seat] += 1;

    seat_release(my_seat);
    sema_up(&threads_done);
}

/* Run thread_fn() on several threads contending for fewer seats than
 * runners, wait for all of them, and check that every visit was counted.
 */
int main(void) {
    /* Threads spawned in addition to main, which also runs thread_fn(). */
    enum { SPAWNED_THREADS = 2, SEAT_COUNT = 2 };
    const int runners = SPAWNED_THREADS + 1;

    sema_init(&threads_done, 0);
    init_seats(SEAT_COUNT);

    for (int i = 0; i < SPAWNED_THREADS; i++)
        thread_new(&thread_fn);
    thread_fn();

    /* Each runner ups threads_done exactly once after releasing its seat. */
    for (int i = 0; i < runners; i++)
        sema_down(&threads_done);

    /* Every runner sits exactly once, so the per-seat tallies must sum to
     * the number of runners. A lost update (e.g. two threads holding the
     * same seat and racing on its counter) would make the sum fall short. */
    int total = 0;
    for (int i = 0; i < num_seats; i++)
        total += seat_used_count[i];
    assert(total == runners);

    return 0;
}
