#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <termios.h>
#include <unistd.h>

//#include "screen.h"

/* Non-zero while raw mode is considered active; keeps enable_raw_mode() and
 * disable_raw_mode() from acting twice. */
int raw_enabled = 0;
/* Terminal attributes saved by enable_raw_mode() and put back by
 * disable_raw_mode(). */
struct termios orig_termios;

/* Restore the terminal attributes saved by enable_raw_mode().
 * Registered with atexit() by enable_raw_mode(), so it runs on normal exit.
 * It is a no-op when raw mode is not active. TCSAFLUSH also discards input
 * that was received but not yet read. */
static void disable_raw_mode(void) {
  if (!raw_enabled) return;
  /* Clear the flag only once the restore succeeded, so the flag keeps
   * reflecting the real terminal state. */
  if (tcsetattr(STDIN_FILENO, TCSAFLUSH, &orig_termios) == 0) raw_enabled = 0;
}

/* Put the terminal on stdin into non-canonical, no-echo mode so each key press
 * is delivered as soon as one byte is available. The original attributes are
 * saved in orig_termios and restored at exit by disable_raw_mode().
 * Does nothing if raw mode is already on, or if tcgetattr() fails (e.g. stdin
 * is not a terminal). */
static void enable_raw_mode(void) {
  static int exit_hook_registered = 0;

  if (raw_enabled) return;
  /* Without a saved copy there is nothing safe to restore later. */
  if (tcgetattr(STDIN_FILENO, &orig_termios) != 0) return;
  raw_enabled = 1;
  /* Register the restore hook once, even if raw mode is toggled repeatedly. */
  if (!exit_hook_registered && atexit(disable_raw_mode) == 0)
    exit_hook_registered = 1;

  struct termios raw = orig_termios;
  raw.c_lflag &= ~(tcflag_t)(ECHO | ICANON);
  /* read() blocks until at least one byte arrives, with no inter-byte timer. */
  raw.c_cc[VMIN] = 1;
  raw.c_cc[VTIME] = 0;
  tcsetattr(STDIN_FILENO, TCSANOW, &raw);
}

/* Ask the terminal for the cursor position with the DSR query "ESC [ 6 n" and
 * parse the "ESC [ line ; col R" reply read directly from stdin.
 * *line and *col are both set to -1 if the query cannot be sent or the reply
 * is missing or malformed. */
static void query_cursor(int *line, int *col) {
  static const char query[] = "\x1b[6n";
  const size_t query_len = sizeof(query) - 1; /* sizeof counts the NUL */
  char answer[32];
  size_t len = 0;
  int l, c;

  *line = -1;
  *col = -1;
  if (write(STDOUT_FILENO, query, query_len) != (ssize_t)query_len) return;

  /* Read byte-by-byte up to the terminating 'R' so a reply split across
   * several reads is still parsed whole; keep one byte for the NUL. */
  while (len < sizeof(answer) - 1) {
    char ch;
    if (read(STDIN_FILENO, &ch, 1) != 1) break;
    answer[len++] = ch;
    if (ch == 'R') break;
  }
  answer[len] = '\0';

  if (len >= 2 && answer[0] == '\x1b' && answer[1] == '[' &&
      sscanf(answer + 2, "%d;%dR", &l, &c) == 2) {
    *line = l;
    *col = c;
  }
}

/* Move the terminal cursor to 1-based (line, col) by emitting the
 * "ESC [ line ; col f" (HVP) control sequence on stdout. */
static void move_cursor(int line, int col) {
  char seq[32]; /* "\x1b[" + two ints + ";f" always fits */
  int n = snprintf(seq, sizeof seq, "\x1b[%d;%df", line, col);
  if (n > 0 && (size_t)n < sizeof seq)
    write(STDOUT_FILENO, seq, (size_t)n);
}

/* Measure the terminal by moving the cursor far past the bottom-right corner
 * and asking where it ended up. This relies on the terminal clamping the move
 * to its last row/column. The cursor is put back afterwards when its starting
 * position could be determined. *line and *col are -1 if the terminal does not
 * answer the position query. */
static void query_term_size(int *line, int *col) {
  int saved_line, saved_col;

  query_cursor(&saved_line, &saved_col);
  move_cursor(999, 999);
  query_cursor(line, col);
  /* Restoring to an unknown (-1) position would emit a malformed "ESC[-1;-1f". */
  if (saved_line > 0 && saved_col > 0)
    move_cursor(saved_line, saved_col);
}

/* Erase the entire display by sending "ESC [ 2 J" to stdout. */
static void clear_screen(void) {
  static const char seq[] = "\x1b[2J";
  /* sizeof counts the NUL terminator; send only the escape sequence itself. */
  write(STDOUT_FILENO, seq, sizeof(seq) - 1);
}

/* Erase the whole line the cursor is on by sending "ESC [ 2 K" to stdout. */
static void clear_line(void) {
  static const char seq[] = "\x1b[2K";
  /* sizeof counts the NUL terminator; send only the escape sequence itself. */
  write(STDOUT_FILENO, seq, sizeof(seq) - 1);
}

/* Text typed on the input row, built up by append_char()/delete_char(). */
struct input_line {
  char *line;  /* heap buffer (calloc/realloc'd); freed by clear_input() and screen_free() */
  size_t len;  /* allocated size of `line` in bytes */
  size_t pos;  /* number of bytes typed so far; also where the next byte is stored */
};

/* State of the split-screen terminal UI. Rows 1..lines-2 hold messages
 * (screen_print_line), row lines-1 is a separator (print_separator) and row
 * `lines` is the input line (redraw_input). */
struct screen {
  struct input_line *input;  /* line currently being edited */
  int lines;                 /* terminal height, measured by query_term_size() in screen_init() */
  int cols;                  /* terminal width, measured the same way */
  int cur_line;              /* last cursor row set via screen_move_cursor(); -1 = unknown */
  int cur_col;               /* last cursor column set via screen_move_cursor(); -1 = unknown */
  char *prompt;              /* strdup'd prompt shown before the input, or NULL; freed by screen_free() */
  pthread_mutex_t lock;      /* held while drawing (redraw_input, print_separator, screen_print_line) */
};
/* Move the terminal cursor to (l, c) and remember that position in *screen,
 * so screen_query_cursor() can answer without asking the terminal. */
void screen_move_cursor(struct screen *screen, int l, int c) {
  screen->cur_line = l;
  screen->cur_col = c;
  move_cursor(l, c);
}
/* Release a screen created by screen_init(): the input buffer, the prompt
 * copy, the mutex and the struct itself. Accepts NULL, like free().
 * Does not restore the terminal mode; the atexit hook registered by
 * enable_raw_mode() takes care of that. */
void screen_free(struct screen *screen) {
  if (screen == NULL) return;
  free(screen->input->line);
  free(screen->input);
  free(screen->prompt);
  pthread_mutex_destroy(&screen->lock);
  free(screen);
}

/* Store byte c at the input position and advance it, growing the heap buffer
 * when needed. The buffer is always left NUL-terminated at input->pos, so it
 * can be printed with %s. On allocation failure the character is dropped and
 * the existing input is kept intact. */
static void append_char(struct screen *screen, int c) {
  struct input_line *input = screen->input;

  /* Need room for the new byte plus the terminating NUL. */
  if (input->pos + 1 >= input->len) {
    size_t old_len = input->len;
    size_t new_len = old_len * 2 + 1;
    if (new_len < input->pos + 2) new_len = input->pos + 2;
    /* Keep the old pointer until realloc succeeds so nothing leaks on OOM. */
    char *grown = realloc(input->line, new_len);
    if (grown == NULL) return;
    memset(grown + old_len, 0, new_len - old_len);
    input->line = grown;
    input->len = new_len;
  }
  input->line[input->pos] = (char)c;
  input->pos++;
  input->line[input->pos] = '\0';
}

/* Backspace: drop the last typed byte and NUL it out. No-op on empty input. */
static void delete_char(struct screen *screen) {
  struct input_line *in = screen->input;
  if (in->pos == 0) return;
  in->pos -= 1;
  in->line[in->pos] = '\0';
}

static void clear_input(struct screen *screen) {
  free(screen->input->line);
  screen->input->line = calloc(1, 1);
  screen->input->len = 1;
  screen->input->pos = 0;
}

/* Smaller of two signed sizes; returns a when they are equal. */
static ssize_t min(ssize_t a, ssize_t b) {
  return b < a ? b : a;
}

/* Repaint the bottom (input) row: the prompt, then as much of the typed text
 * as fits. When the text is longer than the row, it is scrolled left so the
 * cursor stays within the first cols-2 columns. The cursor is left just after
 * the last typed character. Holds screen->lock while drawing. */
void redraw_input(struct screen *screen) {
  pthread_mutex_lock(&screen->lock);
  screen_move_cursor(screen, screen->lines, 1);
  clear_line();

  /* screen_init() accepts a NULL prompt; passing NULL to %s is undefined. */
  const char *prompt = screen->prompt != NULL ? screen->prompt : "";
  int printed = dprintf(STDOUT_FILENO, "%s", prompt);
  /* dprintf returns a negative value on error; treat that as nothing printed. */
  size_t pr_size = printed > 0 ? (size_t)printed : 0;

  size_t typed = screen->input->pos;
  /* 0-based column the cursor would reach if nothing were scrolled. */
  size_t pos = pr_size + typed;
  /* Clamp the limit at 0 so a tiny or unknown width cannot go negative. */
  ssize_t limit = screen->cols > 2 ? (ssize_t)screen->cols - 2 : 0;
  size_t right = (size_t)min((ssize_t)pos, limit);
  /* Typed characters scrolled off the left edge; never more than were typed. */
  size_t skip = pos - right;
  if (skip > typed) skip = typed;
  if (typed > skip) {
    /* Bounded print: never read past the typed text, terminated or not. */
    dprintf(STDOUT_FILENO, "%.*s", (int)(typed - skip), screen->input->line + skip);
  }
  screen_move_cursor(screen, screen->lines, (int)right + 1);
  pthread_mutex_unlock(&screen->lock);
}

char *screen_read_input(struct screen *screen) {
  int c;
  char *res;
  while ((c = getchar()) != '\n' || screen->input->pos == 0) {
    if (c == 127)
      delete_char(screen);
    else if (c == '\x1b') {
      append_char(screen, '^');
      append_char(screen, '[');
    } else if (c != '\n')
      append_char(screen, c);
    redraw_input(screen);
  };
  res = strndup(screen->input->line, screen->input->pos + 1);
  clear_input(screen);
  redraw_input(screen);

  return res;
}

/* Draw a full-width row of '=' on the second-to-last terminal line. It divides
 * the message rows above from the input row below. Holds screen->lock. */
static void print_separator(struct screen *screen) {
  pthread_mutex_lock(&screen->lock);
  screen_move_cursor(screen, screen->lines - 1, 1);
  for (int left = screen->cols; left > 0; --left)
    write(STDOUT_FILENO, "=", 1);
  pthread_mutex_unlock(&screen->lock);
}

/* Report the cursor position tracked in *screen. The terminal is asked only
 * while either coordinate is still unknown (negative). Does not take
 * screen->lock itself; screen_print_line() calls it with the lock held. */
void screen_query_cursor(struct screen *screen, int *l, int *c) {
  const int known = screen->cur_line >= 0 && screen->cur_col >= 0;
  if (!known)
    query_cursor(&screen->cur_line, &screen->cur_col);
  *l = screen->cur_line;
  *c = screen->cur_col;
}
/* Write msg on row `line` of the message area (rows 1..lines-2), truncated to
 * the screen width, then put the cursor back where it was. Out-of-range rows
 * and a NULL msg are ignored. Holds screen->lock while drawing. */
void screen_print_line(struct screen *screen, int line, char *msg) {
  pthread_mutex_lock(&screen->lock);
  if (msg != NULL && line >= 1 && line <= screen->lines - 2) {
    int oldl, oldc;
    screen_query_cursor(screen, &oldl, &oldc);
    screen_move_cursor(screen, line, 1);
    clear_line();

    /* Truncate to the screen width, then emit the visible part with as few
     * write() calls as possible instead of one syscall per byte. */
    size_t width = screen->cols > 0 ? (size_t)screen->cols : 0;
    size_t n = 0;
    while (n < width && msg[n] != '\0') n++;
    const char *p = msg;
    while (n > 0) {
      ssize_t w = write(STDOUT_FILENO, p, n);
      if (w <= 0) break;
      p += w;
      n -= (size_t)w;
    }
    screen_move_cursor(screen, oldl, oldc);
  }
  pthread_mutex_unlock(&screen->lock);
}

void screen_get_size(struct screen *screen, int *lines, int *cols) {
  *lines = screen->lines;
  *cols = screen->cols;
}

/* Enter raw mode, clear the display and build a screen. Its bottom row is the
 * input line and its second-to-last row is a separator. prompt is copied
 * (NULL is allowed). Returns NULL if any allocation or mutex initialisation
 * fails. Raw mode stays on in that case and is restored at exit by the
 * atexit hook. */
struct screen *screen_init(const char *prompt) {
  enable_raw_mode();
  clear_screen();

  struct screen *screen = malloc(sizeof *screen);
  if (screen == NULL) return NULL;
  screen->input = malloc(sizeof *screen->input);
  if (screen->input == NULL) goto fail_screen;
  screen->input->line = calloc(1, 1);
  if (screen->input->line == NULL) goto fail_input;
  screen->input->len = 1;
  screen->input->pos = 0;
  screen->cur_line = -1;
  screen->cur_col = -1;
  screen->prompt = NULL;
  if (prompt != NULL && (screen->prompt = strdup(prompt)) == NULL)
    goto fail_line;

  query_term_size(&screen->lines, &screen->cols);
  /* The terminal did not answer the size query; assume a conventional
   * 24x80 layout rather than drawing at negative coordinates. */
  if (screen->lines <= 0) screen->lines = 24;
  if (screen->cols <= 0) screen->cols = 80;

  if (pthread_mutex_init(&screen->lock, NULL) != 0) goto fail_prompt;
  print_separator(screen);
  redraw_input(screen);
  return screen;

fail_prompt:
  free(screen->prompt);
fail_line:
  free(screen->input->line);
fail_input:
  free(screen->input);
fail_screen:
  free(screen);
  return NULL;
}
