diff --git a/Makefile b/Makefile index bdf94b2..b57528a 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,9 @@ CC=cc -CFLAGS=-c -Wall -Wextra -fPIC -LDFLAGS=-lreadline +CFLAGS=-c -Wall -Wextra -Werror -fPIC -ggdb -O2 -fno-omit-frame-pointer -fno-strict-aliasing -rdynamic +LDFLAGS=-lreadline -lssl -lcrypto -lrt -lz -ggdb -rdynamic LD=cc -SRC=main.c loop.c interface.c +SRC=main.c loop.c interface.c net.c mtproto-common.c mtproto-client.c queries.c structures.c OBJ=$(SRC:.c=.o) EXE=telegram diff --git a/interface.c b/interface.c index f43a7bf..472df60 100644 --- a/interface.c +++ b/interface.c @@ -4,10 +4,11 @@ #include #include - +#include #include #include #include "include.h" +#include "queries.h" char *default_prompt = ">"; char *get_default_prompt (void) { @@ -21,11 +22,13 @@ char *complete_none (const char *text UU, int state UU) { char *commands[] = { "help", "msg", + "contact_list", 0 }; int commands_flags[] = { 070, 072, + 00, }; char *a = 0; @@ -151,5 +154,27 @@ char **complete_text (char *text, int start UU, int end UU) { } void interpreter (char *line UU) { - assert (0); + if (!memcmp (line, "contact_list", 12)) { + do_update_contact_list (); + } +} + +void rprintf (const char *format, ...) { + + int saved_point = rl_point; + char *saved_line = rl_copy_text(0, rl_end); + rl_save_prompt(); + rl_replace_line("", 0); + rl_redisplay(); + + va_list ap; + va_start (ap, format); + vfprintf (stdout, format, ap); + va_end (ap); + + rl_restore_prompt(); + rl_replace_line(saved_line, 0); + rl_point = saved_point; + rl_redisplay(); + free(saved_line); } diff --git a/interface.h b/interface.h index 763dde8..a0d2376 100644 --- a/interface.h +++ b/interface.h @@ -4,4 +4,6 @@ char *get_default_prompt (void); char *complete_none (const char *text, int state); char **complete_text (char *text, int start, int end); void interpreter (char *line); + +void rprintf (const char *format, ...) __attribute__ ((format (printf, 1, 2))); #endif diff --git a/loop.c b/loop.c index 25c6822..f5d515a 100644 --- a/loop.c +++ b/loop.c @@ -9,62 +9,212 @@ #include #include +#include +#include +#include +#include +#include #include "interface.h" +#include "net.h" +#include "mtproto-client.h" +#include "mtproto-common.h" +#include "queries.h" +#include "telegram.h" + extern char *default_username; extern char *auth_token; void set_default_username (const char *s); +int default_dc_num; +void net_loop (int flags, int (*is_end)(void)) { + while (!is_end ()) { + struct pollfd fds[101]; + int cc = 0; + if (flags & 1) { + fds[0].fd = 0; + fds[0].events = POLLIN; + cc ++; + } -int main_loop (void) { - fd_set inp, outp; - struct timeval tv; - while (1) { - FD_ZERO (&inp); - FD_ZERO (&outp); - FD_SET (0, &inp); - tv.tv_sec = 1; - tv.tv_usec = 0; - - int lfd = 0; - - if (select (lfd + 1, &inp, &outp, NULL, &tv) < 0) { - if (errno == EINTR) { - /* resuming from interrupt, so not an error situation, - this generally happens when you suspend your - messenger with "C-z" and then "fg". This is allowed " - */ + int x = connections_make_poll_array (fds + cc, 101 - cc) + cc; + double timer = next_timer_in (); + if (timer > 1000) { timer = 1000; } + if (poll (fds, x, timer) < 0) { + /* resuming from interrupt, so not an error situation, + this generally happens when you suspend your + messenger with "C-z" and then "fg". This is allowed " + */ + if (flags & 1) { rl_reset_line_state (); rl_forced_update_display (); - continue; } - perror ("select()"); - break; + work_timers (); + continue; } - - if (FD_ISSET (0, &inp)) { + work_timers (); + if ((flags & 1) && (fds[0].revents & POLLIN)) { rl_callback_read_char (); } + connections_poll_result (fds + cc, x - cc); } +} + +int ret1 (void) { return 0; } + +int main_loop (void) { + net_loop (1, ret1); return 0; } -int loop (void) { - size_t size = 0; - char *user = default_username; - if (!user && !auth_token) { - printf ("Telephone number (with '+' sign): "); - if (getline (&user, &size, stdin) == -1) { - perror ("getline()"); - exit (EXIT_FAILURE); - } - user[strlen (user) - 1] = '\0'; - set_default_username (user); +struct dc *DC_list[MAX_DC_ID + 1]; +struct dc *DC_working; +int dc_working_num; +int auth_state; +char *get_auth_key_filename (void); +int zero[512]; + + +void write_dc (int auth_file_fd, struct dc *DC) { + assert (write (auth_file_fd, &DC->port, 4) == 4); + int l = strlen (DC->ip); + assert (write (auth_file_fd, &l, 4) == 4); + assert (write (auth_file_fd, DC->ip, l) == l); + if (DC->flags & 1) { + assert (write (auth_file_fd, &DC->auth_key_id, 8) == 8); + assert (write (auth_file_fd, DC->auth_key, 256) == 256); + } else { + assert (write (auth_file_fd, zero, 256 + 8) == 256 + 8); } + + assert (write (auth_file_fd, &DC->server_salt, 8) == 8); +} + +void write_auth_file (void) { + int auth_file_fd = open (get_auth_key_filename (), O_CREAT | O_RDWR, S_IRWXU); + assert (auth_file_fd >= 0); + int x = DC_SERIALIZED_MAGIC; + assert (write (auth_file_fd, &x, 4) == 4); + x = MAX_DC_ID; + assert (write (auth_file_fd, &x, 4) == 4); + assert (write (auth_file_fd, &dc_working_num, 4) == 4); + assert (write (auth_file_fd, &auth_state, 4) == 4); + int i; + for (i = 0; i <= MAX_DC_ID; i++) { + if (DC_list[i]) { + x = 1; + assert (write (auth_file_fd, &x, 4) == 4); + write_dc (auth_file_fd, DC_list[i]); + } else { + x = 0; + assert (write (auth_file_fd, &x, 4) == 4); + } + } + close (auth_file_fd); +} + +void read_dc (int auth_file_fd, int id) { + int port = 0; + assert (read (auth_file_fd, &port, 4) == 4); + int l = 0; + assert (read (auth_file_fd, &l, 4) == 4); + assert (l >= 0); + char *ip = malloc (l + 1); + assert (read (auth_file_fd, ip, l) == l); + ip[l] = 0; + struct dc *DC = alloc_dc (id, ip, port); + assert (read (auth_file_fd, &DC->auth_key_id, 8) == 8); + assert (read (auth_file_fd, &DC->auth_key, 256) == 256); + assert (read (auth_file_fd, &DC->server_salt, 8) == 8); + if (DC->auth_key_id) { + DC->flags |= 1; + } +} + +void empty_auth_file (void) { + struct dc *DC = alloc_dc (1, strdup (TG_SERVER), 443); + assert (DC); + dc_working_num = 1; + write_auth_file (); +} + +void read_auth_file (void) { + int auth_file_fd = open (get_auth_key_filename (), O_CREAT | O_RDWR, S_IRWXU); + if (auth_file_fd < 0) { + empty_auth_file (); + } + assert (auth_file_fd >= 0); + int x; + if (read (auth_file_fd, &x, 4) < 4 || x != DC_SERIALIZED_MAGIC) { + close (auth_file_fd); + empty_auth_file (); + return; + } + assert (read (auth_file_fd, &x, 4) == 4); + assert (x >= 0 && x <= MAX_DC_ID); + assert (read (auth_file_fd, &dc_working_num, 4) == 4); + assert (read (auth_file_fd, &auth_state, 4) == 4); + int i; + for (i = 0; i <= x; i++) { + int y; + assert (read (auth_file_fd, &y, 4) == 4); + if (y) { + read_dc (auth_file_fd, i); + } + } + close (auth_file_fd); +} + +int loop (void) { + on_start (); + read_auth_file (); + assert (DC_list[dc_working_num]); + DC_working = DC_list[dc_working_num]; + if (!DC_working->auth_key_id) { + dc_authorize (DC_working); + } else { + dc_create_session (DC_working); + } + if (!auth_state) { + if (!default_username) { + size_t size = 0; + char *user = 0; + + if (!user && !auth_token) { + printf ("Telephone number (with '+' sign): "); + if (getline (&user, &size, stdin) == -1) { + perror ("getline()"); + exit (EXIT_FAILURE); + } + user[strlen (user) - 1] = 0; + set_default_username (user); + } + } + do_send_code (default_username); + char *code = 0; + size_t size = 0; + printf ("Code from sms: "); + while (1) { + if (getline (&code, &size, stdin) == -1) { + perror ("getline()"); + exit (EXIT_FAILURE); + } + code[strlen (code) - 1] = 0; + if (do_send_code_result (code) >= 0) { + break; + } + printf ("Invalid code. Try again: "); + } + auth_state = 1; + } + + write_auth_file (); fflush (stdin); + fflush (stdout); + fflush (stderr); rl_callback_handler_install (get_default_prompt (), interpreter); rl_attempted_completion_function = (CPPFunction *) complete_text; diff --git a/loop.h b/loop.h index d70a3e6..88b38ee 100644 --- a/loop.h +++ b/loop.h @@ -1,4 +1,6 @@ #ifndef __LOOP_H__ #define __LOOP_H__ int loop (void); +void net_loop (int flags, int (*end)(void)); +void write_auth_file (void); #endif diff --git a/main.c b/main.c index 8de3c38..5d5c557 100644 --- a/main.c +++ b/main.c @@ -27,14 +27,18 @@ #include #include #include +#include +#include #include "loop.h" +#include "mtproto-client.h" #define PROGNAME "telegram-client" #define VERSION "0.01" #define CONFIG_DIRECTORY ".telegram/" #define CONFIG_FILE CONFIG_DIRECTORY "config" +#define AUTH_KEY_FILE CONFIG_DIRECTORY "auth" #define DOWNLOADS_DIRECTORY "downloads/" #define CONFIG_DIRECTORY_MODE 0700 @@ -72,6 +76,13 @@ void get_terminal_attributes (void) { old_lflag = term.c_lflag; old_vtime = term.c_cc[VTIME]; } + +void set_terminal_attributes (void) { + if (tcsetattr (STDIN_FILENO, 0, &term) < 0) { + perror ("tcsetattr()"); + exit (EXIT_FAILURE); + } +} /* }}} */ char *get_home_directory (void) { @@ -107,6 +118,15 @@ char *get_config_filename (void) { return config_filename; } +char *get_auth_key_filename (void) { + char *auth_key_filename; + int length = strlen (get_home_directory ()) + strlen (AUTH_KEY_FILE) + 2; + + auth_key_filename = (char *) calloc (length, sizeof (char)); + sprintf (auth_key_filename, "%s/" AUTH_KEY_FILE, get_home_directory ()); + return auth_key_filename; +} + char *get_downloads_directory (void) { char *downloads_directory; @@ -149,6 +169,11 @@ void running_for_first_time (void) { exit (EXIT_FAILURE); } close (config_file_fd); + int auth_file_fd = open (get_auth_key_filename (), O_CREAT | O_RDWR, S_IRWXU); + int x = -1; + assert (write (auth_file_fd, &x, 4) == 4); + close (auth_file_fd); + printf ("[%s] created\n", config_filename); /* create downloads directory */ @@ -170,13 +195,26 @@ void usage (void) { exit (1); } +extern char *rsa_public_key_name; +extern int verbosity; +extern int default_dc_num; + void args_parse (int argc, char **argv) { int opt = 0; - while ((opt = getopt (argc, argv, "u:h")) != -1) { + while ((opt = getopt (argc, argv, "u:hk:vn:")) != -1) { switch (opt) { case 'u': set_default_username (optarg); break; + case 'k': + rsa_public_key_name = strdup (optarg); + break; + case 'v': + verbosity ++; + break; + case 'n': + default_dc_num = atoi (optarg); + break; case 'h': default: usage (); @@ -185,7 +223,23 @@ void args_parse (int argc, char **argv) { } } +void print_backtrace (void) { + void *buffer[255]; + const int calls = backtrace (buffer, sizeof (buffer) / sizeof (void *)); + backtrace_symbols_fd (buffer, calls, 1); + exit(EXIT_FAILURE); +} + +void sig_handler (int signum) { + set_terminal_attributes (); + printf ("signal %d received\n", signum); + print_backtrace (); +} + + int main (int argc, char **argv) { + signal (SIGSEGV, sig_handler); + signal (SIGABRT, sig_handler); running_for_first_time (); get_terminal_attributes (); diff --git a/mtproto-client.c b/mtproto-client.c new file mode 100644 index 0000000..5a6802b --- /dev/null +++ b/mtproto-client.c @@ -0,0 +1,864 @@ +#define _FILE_OFFSET_BITS 64 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "net.h" +#include "include.h" +#include "queries.h" +#include "loop.h" + +#define sha1 SHA1 + +#include "mtproto-common.h" + +#define MAX_NET_RES (1L << 16) + +int verbosity; +int auth_success; +enum dc_state c_state; +char nonce[256]; +char new_nonce[256]; +char server_nonce[256]; + +int rpc_execute (struct connection *c, int op, int len); +int rpc_becomes_ready (struct connection *c); +int rpc_close (struct connection *c); + +struct connection_methods auth_methods = { + .execute = rpc_execute, + .ready = rpc_becomes_ready, + .close = rpc_close +}; + +long long precise_time; +long long precise_time_rdtsc; +double get_utime (int clock_id) { + struct timespec T; + #if _POSIX_TIMERS + assert (clock_gettime (clock_id, &T) >= 0); + double res = T.tv_sec + (double) T.tv_nsec * 1e-9; + #else + #error "No high-precision clock" + double res = time (); + #endif + if (clock_id == CLOCK_REALTIME) { + precise_time = (long long) (res * (1LL << 32)); + precise_time_rdtsc = rdtsc (); + } + return res; +} + + + +#define STATS_BUFF_SIZE (64 << 10) +int stats_buff_len; +char stats_buff[STATS_BUFF_SIZE]; + +#define MAX_RESPONSE_SIZE (1L << 24) + +char Response[MAX_RESPONSE_SIZE]; +int Response_len; + +/* + * + * STATE MACHINE + * + */ + +char *rsa_public_key_name = "id_rsa.pub"; +RSA *pubKey; +long long pk_fingerprint; + +static int rsa_load_public_key (const char *public_key_name) { + pubKey = NULL; + FILE *f = fopen (public_key_name, "r"); + if (f == NULL) { + fprintf (stderr, "Couldn't open public key file: %s\n", public_key_name); + return -1; + } + pubKey = PEM_read_RSAPublicKey (f, NULL, NULL, NULL); + fclose (f); + if (pubKey == NULL) { + fprintf (stderr, "PEM_read_RSAPublicKey returns NULL.\n"); + return -1; + } + + return 0; +} + + + + + +int auth_work_start (struct connection *c); + +/* + * + * UNAUTHORIZED (DH KEY EXCHANGE) PROTOCOL PART + * + */ + +BIGNUM dh_prime, dh_g, g_a, dh_power, auth_key_num; +char s_power [256]; + +struct { + long long auth_key_id; + long long out_msg_id; + int msg_len; +} unenc_msg_header; + + +#define ENCRYPT_BUFFER_INTS 16384 +int encrypt_buffer[ENCRYPT_BUFFER_INTS]; + +#define DECRYPT_BUFFER_INTS 16384 +int decrypt_buffer[ENCRYPT_BUFFER_INTS]; + +int encrypt_packet_buffer (void) { + return pad_rsa_encrypt ((char *) packet_buffer, (packet_ptr - packet_buffer) * 4, (char *) encrypt_buffer, ENCRYPT_BUFFER_INTS * 4, pubKey->n, pubKey->e); +} + +int encrypt_packet_buffer_aes_unauth (const char server_nonce[16], const char hidden_client_nonce[32]) { + init_aes_unauth (server_nonce, hidden_client_nonce, AES_ENCRYPT); + return pad_aes_encrypt ((char *) packet_buffer, (packet_ptr - packet_buffer) * 4, (char *) encrypt_buffer, ENCRYPT_BUFFER_INTS * 4); +} + + +int rpc_send_packet (struct connection *c) { + int len = (packet_ptr - packet_buffer) * 4; + c->out_packet_num ++; + long long next_msg_id = (long long) ((1LL << 32) * get_utime (CLOCK_REALTIME)) & -4; + if (next_msg_id <= unenc_msg_header.out_msg_id) { + unenc_msg_header.out_msg_id += 4; + } else { + unenc_msg_header.out_msg_id = next_msg_id; + } + unenc_msg_header.msg_len = len; + + int total_len = len + 20; + assert (total_len > 0 && !(total_len & 0xfc000003)); + total_len >>= 2; + if (total_len < 0x7f) { + assert (write_out (c, &total_len, 1) == 1); + } else { + total_len = (total_len << 8) | 0x7f; + assert (write_out (c, &total_len, 4) == 4); + } + write_out (c, &unenc_msg_header, 20); + write_out (c, packet_buffer, len); + flush_out (c); + return 1; +} + +int rpc_send_message (struct connection *c, void *data, int len) { + assert (len > 0 && !(len & 0xfc000003)); + int total_len = len >> 2; + if (total_len < 0x7f) { + assert (write_out (c, &total_len, 1) == 1); + } else { + total_len = (total_len << 8) | 0x7f; + assert (write_out (c, &total_len, 4) == 4); + } + c->out_packet_num ++; + write_out (c, data, len); + flush_out (c); + return 1; +} + +int send_req_pq_packet (struct connection *c) { + assert (c_state == st_init); + assert (RAND_pseudo_bytes ((unsigned char *) nonce, 16) >= 0); + unenc_msg_header.out_msg_id = 0; + clear_packet (); + out_int (CODE_req_pq); + out_ints ((int *)nonce, 4); + rpc_send_packet (c); + c_state = st_reqpq_sent; + return 1; +} + + +unsigned long long gcd (unsigned long long a, unsigned long long b) { + return b ? gcd (b, a % b) : a; +} + +//typedef unsigned int uint128_t __attribute__ ((mode(TI))); +unsigned long long what; +unsigned p1, p2; + +int process_respq_answer (struct connection *c, char *packet, int len) { + int i; + if (verbosity) { + fprintf (stderr, "process_respq_answer(), len=%d\n", len); + } + assert (len >= 76); + assert (!*(long long *) packet); + assert (*(int *) (packet + 16) == len - 20); + assert (!(len & 3)); + assert (*(int *) (packet + 20) == CODE_resPQ); + assert (!memcmp (packet + 24, nonce, 16)); + memcpy (server_nonce, packet + 40, 16); + char *from = packet + 56; + int clen = *from++; + assert (clen <= 8); + what = 0; + for (i = 0; i < clen; i++) { + what = (what << 8) + (unsigned char)*from++; + } + + while (((unsigned long)from) & 3) ++from; + + p1 = 0, p2 = 0; + + if (verbosity >= 2) { + fprintf (stderr, "%lld received\n", what); + } + + int it = 0; + unsigned long long g = 0; + for (i = 0; i < 3 || it < 1000; i++) { + int q = ((lrand48() & 15) + 17) % what; + unsigned long long x = (long long)lrand48 () % (what - 1) + 1, y = x; + int lim = 1 << (i + 18); + int j; + for (j = 1; j < lim; j++) { + ++it; + unsigned long long a = x, b = x, c = q; + while (b) { + if (b & 1) { + c += a; + if (c >= what) { + c -= what; + } + } + a += a; + if (a >= what) { + a -= what; + } + b >>= 1; + } + x = c; + unsigned long long z = x < y ? what + x - y : x - y; + g = gcd (z, what); + if (g != 1) { + break; + } + if (!(j & (j - 1))) { + y = x; + } + } + if (g > 1 && g < what) break; + } + + assert (g > 1 && g < what); + p1 = g; + p2 = what / g; + if (p1 > p2) { + unsigned t = p1; p1 = p2; p2 = t; + } + + + if (verbosity) { + fprintf (stderr, "p1 = %d, p2 = %d, %d iterations\n", p1, p2, it); + } + + /// ++p1; /// + + assert (*(int *) (from) == CODE_vector); + int fingerprints_num = *(int *)(from + 4); + assert (fingerprints_num >= 1 && fingerprints_num <= 64 && len == fingerprints_num * 8 + 8 + (from - packet)); + long long *fingerprints = (long long *) (from + 8); + for (i = 0; i < fingerprints_num; i++) { + if (fingerprints[i] == pk_fingerprint) { + //fprintf (stderr, "found our public key at position %d\n", i); + break; + } + } + if (i == fingerprints_num) { + fprintf (stderr, "fatal: don't have any matching keys (%016llx expected)\n", pk_fingerprint); + exit (2); + } + // create inner part (P_Q_inner_data) + clear_packet (); + packet_ptr += 5; + out_int (CODE_p_q_inner_data); + out_cstring (packet + 57, clen); + //out_int (0x0f01); // pq=15 + + if (p1 < 256) { + clen = 1; + } else if (p1 < 65536) { + clen = 2; + } else if (p1 < 16777216) { + clen = 3; + } else { + clen = 4; + } + p1 = __builtin_bswap32 (p1); + out_cstring ((char *)&p1 + 4 - clen, clen); + p1 = __builtin_bswap32 (p1); + + if (p2 < 256) { + clen = 1; + } else if (p2 < 65536) { + clen = 2; + } else if (p2 < 16777216) { + clen = 3; + } else { + clen = 4; + } + p2 = __builtin_bswap32 (p2); + out_cstring ((char *)&p2 + 4 - clen, clen); + p2 = __builtin_bswap32 (p2); + + //out_int (0x0301); // p=3 + //out_int (0x0501); // q=5 + out_ints ((int *) nonce, 4); + out_ints ((int *) server_nonce, 4); + assert (RAND_pseudo_bytes ((unsigned char *) new_nonce, 32) >= 0); + out_ints ((int *) new_nonce, 8); + sha1 ((unsigned char *) (packet_buffer + 5), (packet_ptr - packet_buffer - 5) * 4, (unsigned char *) packet_buffer); + + int l = encrypt_packet_buffer (); + + clear_packet (); + out_int (CODE_req_DH_params); + out_ints ((int *) nonce, 4); + out_ints ((int *) server_nonce, 4); + //out_int (0x0301); // p=3 + //out_int (0x0501); // q=5 + if (p1 < 256) { + clen = 1; + } else if (p1 < 65536) { + clen = 2; + } else if (p1 < 16777216) { + clen = 3; + } else { + clen = 4; + } + p1 = __builtin_bswap32 (p1); + out_cstring ((char *)&p1 + 4 - clen, clen); + p1 = __builtin_bswap32 (p1); + if (p2 < 256) { + clen = 1; + } else if (p2 < 65536) { + clen = 2; + } else if (p2 < 16777216) { + clen = 3; + } else { + clen = 4; + } + p2 = __builtin_bswap32 (p2); + out_cstring ((char *)&p2 + 4 - clen, clen); + p2 = __builtin_bswap32 (p2); + + out_long (pk_fingerprint); + out_cstring ((char *) encrypt_buffer, l); + + c_state = st_reqdh_sent; + + return rpc_send_packet (c); +} + +int process_dh_answer (struct connection *c, char *packet, int len) { + if (verbosity) { + fprintf (stderr, "process_dh_answer(), len=%d\n", len); + } + if (len < 116) { + fprintf (stderr, "%u * %u = %llu", p1, p2, what); + } + assert (len >= 116); + assert (!*(long long *) packet); + assert (*(int *) (packet + 16) == len - 20); + assert (!(len & 3)); + assert (*(int *) (packet + 20) == (int)CODE_server_DH_params_ok); + assert (!memcmp (packet + 24, nonce, 16)); + assert (!memcmp (packet + 40, server_nonce, 16)); + init_aes_unauth (server_nonce, new_nonce, AES_DECRYPT); + in_ptr = (int *)(packet + 56); + in_end = (int *)(packet + len); + int l = prefetch_strlen (); + assert (l > 0); + l = pad_aes_decrypt (fetch_str (l), l, (char *) decrypt_buffer, DECRYPT_BUFFER_INTS * 4 - 16); + assert (in_ptr == in_end); + assert (l >= 60); + assert (decrypt_buffer[5] == (int)CODE_server_DH_inner_data); + assert (!memcmp (decrypt_buffer + 6, nonce, 16)); + assert (!memcmp (decrypt_buffer + 10, server_nonce, 16)); + assert (decrypt_buffer[14] == 2); + in_ptr = decrypt_buffer + 15; + in_end = decrypt_buffer + (l >> 2); + BN_init (&dh_prime); + BN_init (&g_a); + assert (fetch_bignum (&dh_prime) > 0); + assert (fetch_bignum (&g_a) > 0); + int server_time = *in_ptr++; + assert (in_ptr <= in_end); + + static char sha1_buffer[20]; + sha1 ((unsigned char *) decrypt_buffer + 20, (in_ptr - decrypt_buffer - 5) * 4, (unsigned char *) sha1_buffer); + assert (!memcmp (decrypt_buffer, sha1_buffer, 20)); + assert ((char *) in_end - (char *) in_ptr < 16); + + GET_DC(c)->server_time_delta = server_time - time (0); + GET_DC(c)->server_time_udelta = server_time - get_utime (CLOCK_MONOTONIC); + //fprintf (stderr, "server time is %d, delta = %d\n", server_time, server_time_delta); + + // Build set_client_DH_params answer + clear_packet (); + packet_ptr += 5; + out_int (CODE_client_DH_inner_data); + out_ints ((int *) nonce, 4); + out_ints ((int *) server_nonce, 4); + out_long (0LL); + + BN_init (&dh_g); + BN_set_word (&dh_g, 2); + + assert (RAND_pseudo_bytes ((unsigned char *)s_power, 256) >= 0); + BIGNUM *dh_power = BN_new (); + assert (BN_bin2bn ((unsigned char *)s_power, 256, dh_power) == dh_power); + + BIGNUM *y = BN_new (); + assert (BN_mod_exp (y, &dh_g, dh_power, &dh_prime, BN_ctx) == 1); + out_bignum (y); + BN_free (y); + + BN_init (&auth_key_num); + assert (BN_mod_exp (&auth_key_num, &g_a, dh_power, &dh_prime, BN_ctx) == 1); + l = BN_num_bytes (&auth_key_num); + assert (l >= 250 && l <= 256); + assert (BN_bn2bin (&auth_key_num, (unsigned char *)GET_DC(c)->auth_key)); + memset (GET_DC(c)->auth_key + l, 0, 256 - l); + BN_free (dh_power); + BN_free (&auth_key_num); + BN_free (&dh_g); + BN_free (&g_a); + BN_free (&dh_prime); + + //hexdump (auth_key, auth_key + 256); + + sha1 ((unsigned char *) (packet_buffer + 5), (packet_ptr - packet_buffer - 5) * 4, (unsigned char *) packet_buffer); + + //hexdump ((char *)packet_buffer, (char *)packet_ptr); + + l = encrypt_packet_buffer_aes_unauth (server_nonce, new_nonce); + + clear_packet (); + out_int (CODE_set_client_DH_params); + out_ints ((int *) nonce, 4); + out_ints ((int *) server_nonce, 4); + out_cstring ((char *) encrypt_buffer, l); + + c_state = st_client_dh_sent; + + return rpc_send_packet (c); +} + + +int process_auth_complete (struct connection *c UU, char *packet, int len) { + if (verbosity) { + fprintf (stderr, "process_dh_answer(), len=%d\n", len); + } + assert (len == 72); + assert (!*(long long *) packet); + assert (*(int *) (packet + 16) == len - 20); + assert (!(len & 3)); + assert (*(int *) (packet + 20) == CODE_dh_gen_ok); + assert (!memcmp (packet + 24, nonce, 16)); + assert (!memcmp (packet + 40, server_nonce, 16)); + static unsigned char tmp[44], sha1_buffer[20]; + memcpy (tmp, new_nonce, 32); + tmp[32] = 1; + sha1 ((unsigned char *)GET_DC(c)->auth_key, 256, sha1_buffer); + GET_DC(c)->auth_key_id = *(long long *)(sha1_buffer + 12); + memcpy (tmp + 33, sha1_buffer, 8); + sha1 (tmp, 41, sha1_buffer); + assert (!memcmp (packet + 56, sha1_buffer + 4, 16)); + GET_DC(c)->server_salt = *(long long *)server_nonce ^ *(long long *)new_nonce; + if (verbosity >= 3) { + fprintf (stderr, "auth_key_id=%016llx\n", GET_DC(c)->auth_key_id); + } + //kprintf ("OK\n"); + + //c->status = conn_error; + //sleep (1); + + c_state = st_authorized; + //return 1; + if (verbosity) { + fprintf (stderr, "Auth success\n"); + } + auth_success ++; + GET_DC(c)->flags |= 1; + write_auth_file (); + return 1; +} + +/* + * + * AUTHORIZED (MAIN) PROTOCOL PART + * + */ + +struct encrypted_message enc_msg; + +long long client_last_msg_id, server_last_msg_id; + +double get_server_time (struct dc *DC) { + if (!DC->server_time_udelta) { + DC->server_time_udelta = get_utime (CLOCK_REALTIME) - get_utime (CLOCK_MONOTONIC); + } + return get_utime (CLOCK_MONOTONIC) + DC->server_time_udelta; +} + +long long generate_next_msg_id (struct dc *DC) { + long long next_id = (long long) (get_server_time (DC) * (1LL << 32)) & -4; + if (next_id <= client_last_msg_id) { + next_id = client_last_msg_id += 4; + } else { + client_last_msg_id = next_id; + } + return next_id; +} + +void init_enc_msg (struct session *S, int useful) { + struct dc *DC = S->dc; + assert (DC->auth_key_id); + enc_msg.auth_key_id = DC->auth_key_id; + assert (DC->server_salt); + enc_msg.server_salt = DC->server_salt; + if (!S->session_id) { + assert (RAND_pseudo_bytes ((unsigned char *) &S->session_id, 8) >= 0); + } + enc_msg.session_id = S->session_id; + //enc_msg.auth_key_id2 = auth_key_id; + enc_msg.msg_id = generate_next_msg_id (DC); + //enc_msg.msg_id -= 0x10000000LL * (lrand48 () & 15); + //kprintf ("message id %016llx\n", enc_msg.msg_id); + enc_msg.seq_no = S->seq_no; + if (useful) { + enc_msg.seq_no |= 1; + } + S->seq_no += 2; +}; + +int aes_encrypt_message (struct dc *DC, struct encrypted_message *enc) { + unsigned char sha1_buffer[20]; + const int MINSZ = offsetof (struct encrypted_message, message); + const int UNENCSZ = offsetof (struct encrypted_message, server_salt); + int enc_len = (MINSZ - UNENCSZ) + enc->msg_len; + assert (enc->msg_len >= 0 && enc->msg_len <= MAX_MESSAGE_INTS * 4 - 16 && !(enc->msg_len & 3)); + sha1 ((unsigned char *) &enc->server_salt, enc_len, sha1_buffer); + //printf ("enc_len is %d\n", enc_len); + if (verbosity >= 2) { + fprintf (stderr, "sending message with sha1 %08x\n", *(int *)sha1_buffer); + } + memcpy (enc->msg_key, sha1_buffer + 4, 16); + init_aes_auth (DC->auth_key, enc->msg_key, AES_ENCRYPT); + //hexdump ((char *)enc, (char *)enc + enc_len + 24); + return pad_aes_encrypt ((char *) &enc->server_salt, enc_len, (char *) &enc->server_salt, MAX_MESSAGE_INTS * 4 + (MINSZ - UNENCSZ)); +} + +long long encrypt_send_message (struct connection *c, int *msg, int msg_ints, int useful) { + struct dc *DC = GET_DC(c); + struct session *S = c->session; + assert (S); + const int UNENCSZ = offsetof (struct encrypted_message, server_salt); + if (msg_ints <= 0 || msg_ints > MAX_MESSAGE_INTS - 4) { + return -1; + } + if (msg) { + memcpy (enc_msg.message, msg, msg_ints * 4); + enc_msg.msg_len = msg_ints * 4; + } else { + if ((enc_msg.msg_len & 0x80000003) || enc_msg.msg_len > MAX_MESSAGE_INTS * 4 - 16) { + return -1; + } + } + init_enc_msg (S, useful); + + //hexdump ((char *)msg, (char *)msg + (msg_ints * 4)); + int l = aes_encrypt_message (DC, &enc_msg); + //hexdump ((char *)&enc_msg, (char *)&enc_msg + l + 24); + assert (l > 0); + rpc_send_message (c, &enc_msg, l + UNENCSZ); + + return client_last_msg_id; +} + +int longpoll_count, good_messages; + +int auth_work_start (struct connection *c UU) { + return 1; +} + +void rpc_execute_answer (struct connection *c, long long msg_id UU); +void work_container (struct connection *c, long long msg_id UU) { + if (verbosity) { + fprintf (stderr, "work_container: msg_id = %lld\n", msg_id); + } + assert (fetch_int () == CODE_msg_container); + int n = fetch_int (); + int i; + for (i = 0; i < n; i++) { + long long id = fetch_long (); + int seqno = fetch_int (); + if (seqno & 1) { + insert_seqno (c->session, seqno); + } + int bytes = fetch_int (); + int *t = in_ptr; + rpc_execute_answer (c, id); + assert (in_ptr == t + (bytes / 4)); + } +} + +void work_new_session_created (struct connection *c, long long msg_id UU) { + if (verbosity) { + fprintf (stderr, "work_new_session_created: msg_id = %lld\n", msg_id); + } + assert (fetch_int () == (int)CODE_new_session_created); + fetch_long (); // first message id + //DC->session_id = fetch_long (); + fetch_long (); // unique_id + GET_DC(c)->server_salt = fetch_long (); +} + +void work_msgs_ack (struct connection *c UU, long long msg_id UU) { + if (verbosity) { + fprintf (stderr, "work_msgs_ack: msg_id = %lld\n", msg_id); + } + assert (fetch_int () == CODE_msgs_ack); + assert (fetch_int () == CODE_vector); + int n = fetch_int (); + int i; + for (i = 0; i < n; i++) { + long long id = fetch_long (); + query_ack (id); + } +} + +void work_rpc_result (struct connection *c UU, long long msg_id UU) { + if (verbosity) { + fprintf (stderr, "work_rpc_result: msg_id = %lld\n", msg_id); + } + assert (fetch_int () == (int)CODE_rpc_result); + long long id = fetch_long (); + int op = prefetch_int (); + if (op == CODE_rpc_error) { + query_error (id); + } else { + query_result (id); + } +} + +void rpc_execute_answer (struct connection *c, long long msg_id UU) { + int op = prefetch_int (); + switch (op) { + case CODE_msg_container: + work_container (c, msg_id); + return; + case CODE_new_session_created: + work_new_session_created (c, msg_id); + return; + case CODE_msgs_ack: + work_msgs_ack (c, msg_id); + return; + case CODE_rpc_result: + work_rpc_result (c, msg_id); + return; + } + fprintf (stderr, "Unknown message: \n"); + hexdump_in (); +} + +int process_rpc_message (struct connection *c UU, struct encrypted_message *enc, int len) { + const int MINSZ = offsetof (struct encrypted_message, message); + const int UNENCSZ = offsetof (struct encrypted_message, server_salt); + if (verbosity) { + fprintf (stderr, "process_rpc_message(), len=%d\n", len); + } + assert (len >= MINSZ && (len & 15) == (UNENCSZ & 15)); + struct dc *DC = GET_DC(c); + assert (enc->auth_key_id == DC->auth_key_id); + assert (DC->auth_key_id); + init_aes_auth (DC->auth_key + 8, enc->msg_key, AES_DECRYPT); + int l = pad_aes_decrypt ((char *)&enc->server_salt, len - UNENCSZ, (char *)&enc->server_salt, len - UNENCSZ); + assert (l == len - UNENCSZ); + //assert (enc->auth_key_id2 == enc->auth_key_id); + assert (!(enc->msg_len & 3) && enc->msg_len > 0 && enc->msg_len <= len - MINSZ && len - MINSZ - enc->msg_len <= 12); + static unsigned char sha1_buffer[20]; + sha1 ((void *)&enc->server_salt, enc->msg_len + (MINSZ - UNENCSZ), sha1_buffer); + assert (!memcmp (&enc->msg_key, sha1_buffer + 4, 16)); + //assert (enc->server_salt == server_salt); //in fact server salt can change + if (DC->server_salt != enc->server_salt) { + DC->server_salt = enc->server_salt; + write_auth_file (); + } + int this_server_time = enc->msg_id >> 32LL; + double st = get_server_time (DC); + assert (this_server_time >= st - 300 && this_server_time <= st + 30); + //assert (enc->msg_id > server_last_msg_id && (enc->msg_id & 3) == 1); + if (verbosity >= 2) { + fprintf (stderr, "received mesage id %016llx\n", enc->msg_id); + } + server_last_msg_id = enc->msg_id; + + //*(long long *)(longpoll_query + 3) = *(long long *)((char *)(&enc->msg_id) + 0x3c); + //*(long long *)(longpoll_query + 5) = *(long long *)((char *)(&enc->msg_id) + 0x3c); + + assert (l >= (MINSZ - UNENCSZ) + 8); + //assert (enc->message[0] == CODE_rpc_result && *(long long *)(enc->message + 1) == client_last_msg_id); + if (verbosity >= 2) { + fprintf (stderr, "OK, message is good!\n"); + } + ++good_messages; + + in_ptr = enc->message; + in_end = in_ptr + (enc->msg_len / 4); + + if (enc->seq_no & 1) { + insert_seqno (c->session, enc->seq_no); + } + assert (c->session->session_id == enc->session_id); + rpc_execute_answer (c, enc->msg_id); + return 0; +} + + +int rpc_execute (struct connection *c, int op, int len) { + if (verbosity) { + fprintf (stderr, "outbound rpc connection #%d : received rpc answer %d with %d content bytes\n", c->fd, op, len); + } + + if (len >= MAX_RESPONSE_SIZE/* - 12*/ || len < 0/*12*/) { + fprintf (stderr, "answer too long (%d bytes), skipping\n", len); + return 0; + } + + int Response_len = len; + + assert (read_in (c, Response, Response_len) == Response_len); + Response[Response_len] = 0; + if (verbosity >= 2) { + fprintf (stderr, "have %d Response bytes\n", Response_len); + } + + setsockopt (c->fd, IPPROTO_TCP, TCP_QUICKACK, (int[]){0}, 4); + int o = c_state; + if (GET_DC(c)->flags & 1) { o = st_authorized;} + switch (o) { + case st_reqpq_sent: + process_respq_answer (c, Response/* + 8*/, Response_len/* - 12*/); + setsockopt (c->fd, IPPROTO_TCP, TCP_QUICKACK, (int[]){0}, 4); + return 0; + case st_reqdh_sent: + process_dh_answer (c, Response/* + 8*/, Response_len/* - 12*/); + setsockopt (c->fd, IPPROTO_TCP, TCP_QUICKACK, (int[]){0}, 4); + return 0; + case st_client_dh_sent: + process_auth_complete (c, Response/* + 8*/, Response_len/* - 12*/); + setsockopt (c->fd, IPPROTO_TCP, TCP_QUICKACK, (int[]){0}, 4); + return 0; + case st_authorized: + process_rpc_message (c, (void *)(Response/* + 8*/), Response_len/* - 12*/); + setsockopt (c->fd, IPPROTO_TCP, TCP_QUICKACK, (int[]){0}, 4); + return 0; + default: + fprintf (stderr, "fatal: cannot receive answer in state %d\n", c_state); + exit (2); + } + + return 0; +} + + +int tc_close (struct connection *c, int who) { + if (verbosity) { + fprintf (stderr, "outbound http connection #%d : closing by %d\n", c->fd, who); + } + return 0; +} + +int tc_becomes_ready (struct connection *c) { + if (verbosity) { + fprintf (stderr, "outbound connection #%d becomes ready\n", c->fd); + } + char byte = 0xef; + assert (write_out (c, &byte, 1) == 1); + flush_out (c); + + setsockopt (c->fd, IPPROTO_TCP, TCP_QUICKACK, (int[]){0}, 4); + int o = c_state; + if (GET_DC(c)->flags & 1) { o = st_authorized; } + switch (o) { + case st_init: + send_req_pq_packet (c); + break; + case st_authorized: + auth_work_start (c); + break; + default: + fprintf (stderr, "c_state = %d\n", c_state); + assert (0); + } + return 0; +} + +int rpc_becomes_ready (struct connection *c) { + return tc_becomes_ready (c); +} + +int rpc_close (struct connection *c) { + return tc_close (c, 0); +} + +int auth_is_success (void) { + return auth_success; +} + +void on_start (void) { + prng_seed (0, 0); + + if (rsa_load_public_key (rsa_public_key_name) < 0) { + perror ("rsa_load_public_key"); + exit (1); + } + if (verbosity) { + fprintf (stderr, "public key '%s' loaded successfully\n", rsa_public_key_name); + } + pk_fingerprint = compute_rsa_key_fingerprint (pubKey); +} + +int auth_ok (void) { + return auth_success; +} + +void dc_authorize (struct dc *DC) { + c_state = 0; + auth_success = 0; + if (!DC->sessions[0]) { + dc_create_session (DC); + } + if (verbosity) { + fprintf (stderr, "Starting authorization for DC #%d: %s:%d\n", DC->id, DC->ip, DC->port); + } + net_loop (0, auth_ok); +} diff --git a/mtproto-client.h b/mtproto-client.h new file mode 100644 index 0000000..200e128 --- /dev/null +++ b/mtproto-client.h @@ -0,0 +1,7 @@ +#ifndef __MTPROTO_CLIENT_H__ +#define __MTPROTO_CLIENT_H__ +#include "net.h" +void on_start (void); +long long encrypt_send_message (struct connection *c, int *msg, int msg_ints, int useful); +void dc_authorize (struct dc *DC); +#endif diff --git a/mtproto-common.c b/mtproto-common.c new file mode 100644 index 0000000..e00845b --- /dev/null +++ b/mtproto-common.c @@ -0,0 +1,337 @@ +#define _FILE_OFFSET_BITS 64 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mtproto-common.h" + +long long rsa_encrypted_chunks, rsa_decrypted_chunks; + +BN_CTX *BN_ctx; +int verbosity; + +int get_random_bytes (void *buf, int n) { + int r = 0, h = open ("/dev/random", O_RDONLY | O_NONBLOCK); + if (h >= 0) { + r = read (h, buf, n); + if (r > 0) { + if (verbosity >= 3) { + fprintf (stderr, "added %d bytes of real entropy to secure random numbers seed\n", r); + } + } + close (h); + } + + if (r < n) { + h = open ("/dev/urandom", O_RDONLY); + if (h < 0) { + return r; + } + int s = read (h, buf + r, n - r); + close (h); + if (s < 0) { + return r; + } + r += s; + } + + if (r >= (int)sizeof (long)) { + *(long *)buf ^= lrand48 (); + srand48 (*(long *)buf); + } + + return r; +} + + +void prng_seed (const char *password_filename, int password_length) { + unsigned char *a = calloc (64 + password_length, 1); + assert (a != NULL); + long long r = rdtsc (); + struct timespec T; + assert (clock_gettime(CLOCK_REALTIME, &T) >= 0); + memcpy (a, &T.tv_sec, 4); + memcpy (a+4, &T.tv_nsec, 4); + memcpy (a+8, &r, 8); + unsigned short p = getpid (); + memcpy (a + 16, &p, 2); + int s = get_random_bytes (a + 18, 32) + 18; + if (password_filename) { + int fd = open (password_filename, O_RDONLY); + if (fd < 0) { + fprintf (stderr, "Warning: fail to open password file - \"%s\", %m.\n", password_filename); + } else { + int l = read (fd, a + s, password_length); + if (l < 0) { + fprintf (stderr, "Warning: fail to read password file - \"%s\", %m.\n", password_filename); + } else { + if (verbosity > 0) { + fprintf (stderr, "read %d bytes from password file.\n", l); + } + s += l; + } + close (fd); + } + } + RAND_seed (a, s); + BN_ctx = BN_CTX_new (); + memset (a, 0, s); + free (a); +} + +int serialize_bignum (BIGNUM *b, char *buffer, int maxlen) { + int itslen = BN_num_bytes (b); + int reqlen; + if (itslen < 254) { + reqlen = itslen + 1; + } else { + reqlen = itslen + 4; + } + int newlen = (reqlen + 3) & -4; + int pad = newlen - reqlen; + reqlen = newlen; + if (reqlen > maxlen) { + return -reqlen; + } + if (itslen < 254) { + *buffer++ = itslen; + } else { + *(int *)buffer = (itslen << 8) + 0xfe; + buffer += 4; + } + int l = BN_bn2bin (b, (unsigned char *)buffer); + assert (l == itslen); + buffer += l; + while (pad --> 0) { + *buffer++ = 0; + } + return reqlen; +} + + +long long compute_rsa_key_fingerprint (RSA *key) { + static char tempbuff[4096]; + static unsigned char sha[20]; + assert (key->n && key->e); + int l1 = serialize_bignum (key->n, tempbuff, 4096); + assert (l1 > 0); + int l2 = serialize_bignum (key->e, tempbuff + l1, 4096 - l1); + assert (l2 > 0 && l1 + l2 <= 4096); + SHA1 ((unsigned char *)tempbuff, l1 + l2, sha); + return *(long long *)(sha + 12); +} + +void out_cstring (const char *str, long len) { + assert (len >= 0 && len < (1 << 24)); + assert ((char *) packet_ptr + len + 8 < (char *) (packet_buffer + PACKET_BUFFER_SIZE)); + char *dest = (char *) packet_ptr; + if (len < 254) { + *dest++ = len; + } else { + *packet_ptr = (len << 8) + 0xfe; + dest += 4; + } + memcpy (dest, str, len); + dest += len; + while ((long) dest & 3) { + *dest++ = 0; + } + packet_ptr = (int *) dest; +} + +void out_cstring_careful (const char *str, long len) { + assert (len >= 0 && len < (1 << 24)); + assert ((char *) packet_ptr + len + 8 < (char *) (packet_buffer + PACKET_BUFFER_SIZE)); + char *dest = (char *) packet_ptr; + if (len < 254) { + dest++; + if (dest != str) { + memmove (dest, str, len); + } + dest[-1] = len; + } else { + dest += 4; + if (dest != str) { + memmove (dest, str, len); + } + *packet_ptr = (len << 8) + 0xfe; + } + dest += len; + while ((long) dest & 3) { + *dest++ = 0; + } + packet_ptr = (int *) dest; +} + + +void out_data (const char *data, long len) { + assert (len >= 0 && len < (1 << 24) && !(len & 3)); + assert ((char *) packet_ptr + len + 8 < (char *) (packet_buffer + PACKET_BUFFER_SIZE)); + memcpy (packet_ptr, data, len); + packet_ptr += len >> 2; +} + +int *in_ptr, *in_end; + +int fetch_bignum (BIGNUM *x) { + int l = prefetch_strlen (); + if (l < 0) { + return l; + } + char *str = fetch_str (l); + assert (BN_bin2bn ((unsigned char *) str, l, x) == x); + return l; +} + +int pad_rsa_encrypt (char *from, int from_len, char *to, int size, BIGNUM *N, BIGNUM *E) { + int pad = (255000 - from_len - 32) % 255 + 32; + int chunks = (from_len + pad) / 255; + int bits = BN_num_bits (N); + assert (bits >= 2041 && bits <= 2048); + assert (from_len > 0 && from_len <= 2550); + assert (size >= chunks * 256); + assert (RAND_pseudo_bytes ((unsigned char *) from + from_len, pad) >= 0); + int i; + BIGNUM x, y; + BN_init (&x); + BN_init (&y); + rsa_encrypted_chunks += chunks; + for (i = 0; i < chunks; i++) { + BN_bin2bn ((unsigned char *) from, 255, &x); + assert (BN_mod_exp (&y, &x, E, N, BN_ctx) == 1); + unsigned l = 256 - BN_num_bytes (&y); + assert (l <= 256); + memset (to, 0, l); + BN_bn2bin (&y, (unsigned char *) to + l); + to += 256; + } + BN_free (&x); + BN_free (&y); + return chunks * 256; +} + +int pad_rsa_decrypt (char *from, int from_len, char *to, int size, BIGNUM *N, BIGNUM *D) { + if (from_len < 0 || from_len > 0x1000 || (from_len & 0xff)) { + return -1; + } + int chunks = (from_len >> 8); + int bits = BN_num_bits (N); + assert (bits >= 2041 && bits <= 2048); + assert (size >= chunks * 255); + int i; + BIGNUM x, y; + BN_init (&x); + BN_init (&y); + for (i = 0; i < chunks; i++) { + ++rsa_decrypted_chunks; + BN_bin2bn ((unsigned char *) from, 256, &x); + assert (BN_mod_exp (&y, &x, D, N, BN_ctx) == 1); + int l = BN_num_bytes (&y); + if (l > 255) { + BN_free (&x); + BN_free (&y); + return -1; + } + assert (l >= 0 && l <= 255); + memset (to, 0, 255 - l); + BN_bn2bin (&y, (unsigned char *) to + 255 - l); + to += 255; + } + BN_free (&x); + BN_free (&y); + return chunks * 255; +} + +unsigned char aes_key_raw[32], aes_iv[32]; +AES_KEY aes_key; + +void init_aes_unauth (const char server_nonce[16], const char hidden_client_nonce[32], int encrypt) { + static unsigned char buffer[64], hash[20]; + memcpy (buffer, hidden_client_nonce, 32); + memcpy (buffer + 32, server_nonce, 16); + SHA1 (buffer, 48, aes_key_raw); + memcpy (buffer + 32, hidden_client_nonce, 32); + SHA1 (buffer, 64, aes_iv + 8); + memcpy (buffer, server_nonce, 16); + memcpy (buffer + 16, hidden_client_nonce, 32); + SHA1 (buffer, 48, hash); + memcpy (aes_key_raw + 20, hash, 12); + memcpy (aes_iv, hash + 12, 8); + memcpy (aes_iv + 28, hidden_client_nonce, 4); + if (encrypt == AES_ENCRYPT) { + AES_set_encrypt_key (aes_key_raw, 32*8, &aes_key); + } else { + AES_set_decrypt_key (aes_key_raw, 32*8, &aes_key); + } +} + +void init_aes_auth (char auth_key[192], char msg_key[16], int encrypt) { + static unsigned char buffer[48], hash[20]; + // sha1_a = SHA1 (msg_key + substr (auth_key, 0, 32)); + // sha1_b = SHA1 (substr (auth_key, 32, 16) + msg_key + substr (auth_key, 48, 16)); + // sha1_с = SHA1 (substr (auth_key, 64, 32) + msg_key); + // sha1_d = SHA1 (msg_key + substr (auth_key, 96, 32)); + // aes_key = substr (sha1_a, 0, 8) + substr (sha1_b, 8, 12) + substr (sha1_c, 4, 12); + // aes_iv = substr (sha1_a, 8, 12) + substr (sha1_b, 0, 8) + substr (sha1_c, 16, 4) + substr (sha1_d, 0, 8); + memcpy (buffer, msg_key, 16); + memcpy (buffer + 16, auth_key, 32); + SHA1 (buffer, 48, hash); + memcpy (aes_key_raw, hash, 8); + memcpy (aes_iv, hash + 8, 12); + + memcpy (buffer, auth_key + 32, 16); + memcpy (buffer + 16, msg_key, 16); + memcpy (buffer + 32, auth_key + 48, 16); + SHA1 (buffer, 48, hash); + memcpy (aes_key_raw + 8, hash + 8, 12); + memcpy (aes_iv + 12, hash, 8); + + memcpy (buffer, auth_key + 64, 32); + memcpy (buffer + 32, msg_key, 16); + SHA1 (buffer, 48, hash); + memcpy (aes_key_raw + 20, hash + 4, 12); + memcpy (aes_iv + 20, hash + 16, 4); + + memcpy (buffer, msg_key, 16); + memcpy (buffer + 16, auth_key + 96, 32); + SHA1 (buffer, 48, hash); + memcpy (aes_iv + 24, hash, 8); + + if (encrypt == AES_ENCRYPT) { + AES_set_encrypt_key (aes_key_raw, 32*8, &aes_key); + } else { + AES_set_decrypt_key (aes_key_raw, 32*8, &aes_key); + } +} + +int pad_aes_encrypt (char *from, int from_len, char *to, int size) { + int padded_size = (from_len + 15) & -16; + assert (from_len > 0 && padded_size <= size); + if (from_len < padded_size) { + assert (RAND_pseudo_bytes ((unsigned char *) from + from_len, padded_size - from_len) >= 0); + } + AES_ige_encrypt ((unsigned char *) from, (unsigned char *) to, padded_size, &aes_key, aes_iv, AES_ENCRYPT); + return padded_size; +} + +int pad_aes_decrypt (char *from, int from_len, char *to, int size) { + if (from_len <= 0 || from_len > size || (from_len & 15)) { + return -1; + } + AES_ige_encrypt ((unsigned char *) from, (unsigned char *) to, from_len, &aes_key, aes_iv, AES_DECRYPT); + return from_len; +} + + diff --git a/mtproto-common.h b/mtproto-common.h new file mode 100644 index 0000000..808aa89 --- /dev/null +++ b/mtproto-common.h @@ -0,0 +1,416 @@ +#ifndef __MTPROTO_COMMON_H__ +#define __MTPROTO_COMMON_H__ + +#include +#include +#include +#include +#include + +/* DH key exchange protocol data structures */ +#define CODE_req_pq 0x60469778 +#define CODE_resPQ 0x05162463 +#define CODE_req_DH_params 0xd712e4be +#define CODE_p_q_inner_data 0x83c95aec +#define CODE_server_DH_inner_data 0xb5890dba +#define CODE_server_DH_params_fail 0x79cb045d +#define CODE_server_DH_params_ok 0xd0e8075c +#define CODE_set_client_DH_params 0xf5045f1f +#define CODE_client_DH_inner_data 0x6643b654 +#define CODE_dh_gen_ok 0x3bcbf734 +#define CODE_dh_gen_retry 0x46dc1fb9 +#define CODE_dh_gen_fail 0xa69dae02 + +/* generic data structures */ +#define CODE_vector_long 0xc734a64e +#define CODE_vector_int 0xa03855ae +#define CODE_vector_Object 0xa351ae8e +#define CODE_vector 0x1cb5c415 + +/* service messages */ +#define CODE_rpc_result 0xf35c6d01 +#define CODE_rpc_error 0x2144ca19 +#define CODE_msg_container 0x73f1f8dc +#define CODE_msg_copy 0xe06046b2 +#define CODE_http_wait 0x9299359f +#define CODE_msgs_ack 0x62d6b459 +#define CODE_bad_msg_notification 0xa7eff811 +#define CODE_bad_server_salt 0xedab447b +#define CODE_msgs_state_req 0xda69fb52 +#define CODE_msgs_state_info 0x04deb57d +#define CODE_msgs_all_info 0x8cc0d131 +#define CODE_new_session_created 0x9ec20908 +#define CODE_msg_resend_req 0x7d861a08 +#define CODE_ping 0x7abe77ec +#define CODE_pong 0x347773c5 +#define CODE_destroy_session 0xe7512126 +#define CODE_destroy_session_ok 0xe22045fc +#define CODE_destroy_session_none 0x62d350c9 +#define CODE_destroy_sessions 0x9a6face8 +#define CODE_destroy_sessions_res 0xa8164668 +#define CODE_get_future_salts 0xb921bd04 +#define CODE_future_salt 0x0949d9dc +#define CODE_future_salts 0xae500895 +#define CODE_rpc_drop_answer 0x58e4a740 +#define CODE_rpc_answer_unknown 0x5e2ad36e +#define CODE_rpc_answer_dropped_running 0xcd78e586 +#define CODE_rpc_answer_dropped 0xa43ad8b7 +#define CODE_msg_detailed_info 0x276d3ec6 +#define CODE_msg_new_detailed_info 0x809db6df +#define CODE_ping_delay_disconnect 0xf3427b8c + +/* sample rpc query/response structures */ +#define CODE_getUser 0xb0f732d5 +#define CODE_getUsers 0x2d84d5f5 +#define CODE_user 0xd23c81a3 +#define CODE_no_user 0xc67599d1 + +#define CODE_msgs_random 0x12345678 +#define CODE_random_msg 0x87654321 + +#define RPC_INVOKE_REQ 0x2374df3d +#define RPC_INVOKE_KPHP_REQ 0x99a37fda +#define RPC_REQ_RUNNING 0x346d5efa +#define RPC_REQ_ERROR 0x7ae432f5 +#define RPC_REQ_RESULT 0x63aeda4e +#define RPC_READY 0x6a34cac7 +#define RPC_STOP_READY 0x59d86654 +#define RPC_SEND_SESSION_MSG 0x1ed5a3cc +#define RPC_RESPONSE_INDIRECT 0x2194f56e + +/* RPC for workers */ +#define CODE_send_session_msg 0x81bb412c +#define CODE_sendMsgOk 0x29841ee2 +#define CODE_sendMsgNoSession 0x2b2b9e78 +#define CODE_sendMsgFailed 0x4b0cbd57 +#define CODE_get_auth_sessions 0x611f7845 +#define CODE_authKeyNone 0x8a8bc1f3 +#define CODE_authKeySessions 0x6b7f026c +#define CODE_add_session_box 0xe707e295 +#define CODE_set_session_box 0x193d4231 +#define CODE_replace_session_box 0xcb101b49 +#define CODE_replace_session_box_cas 0xb2bbfa78 +#define CODE_delete_session_box 0x01b78d81 +#define CODE_delete_session_box_cas 0xb3fdc3c5 +#define CODE_session_box_no_session 0x43f46c33 +#define CODE_session_box_created 0xe1dd5d40 +#define CODE_session_box_replaced 0xbd9cb6b2 +#define CODE_session_box_deleted 0xaf8fd05e +#define CODE_session_box_not_found 0xb3560a7f +#define CODE_session_box_found 0x560fe356 +#define CODE_session_box_changed 0x014b31b8 +#define CODE_get_session_box 0x8793a924 +#define CODE_get_session_box_cond 0x7888fab6 +#define CODE_session_box_session_absent 0x9e234062 +#define CODE_session_box_absent 0xa1a106eb +#define CODE_session_box 0x7956cd97 +#define CODE_session_box_large 0xb568d189 +#define CODE_get_sessions_activity 0x059dc5f6 +#define CODE_sessions_activities 0x60ce5b1d +#define CODE_get_session_activity 0x96dbac11 +#define CODE_session_activity 0xe175e8e0 + +/* RPC for front/proxy */ +#define RPC_FRONT 0x27a456f3 +#define RPC_FRONT_ACK 0x624abd23 +#define RPC_FRONT_ERR 0x71dda175 +#define RPC_PROXY_REQ 0x36cef1ee +#define RPC_PROXY_ANS 0x4403da0d +#define RPC_CLOSE_CONN 0x1fcf425d +#define RPC_CLOSE_EXT 0x5eb634a2 +#define RPC_SIMPLE_ACK 0x3bac409b + + + +#define CODE_auth_send_code 0xd16ff372 +#define CODE_auth_sent_code 0x2215bcbd +#define CODE_help_get_config 0xc4f9186b +#define CODE_config 0x232d5905 +#define CODE_dc_option 0x2ec2a43c +#define CODE_bool_false 0xbc799737 +#define CODE_bool_true 0x997275b5 +#define CODE_user_self 0x720535ec +#define CODE_auth_authorization 0xf6b673a4 +#define CODE_user_profile_photo_empty 0x4f11bae1 +#define CODE_user_profile_photo 0x990d1493 +#define CODE_user_status_empty 0x9d05049 +#define CODE_user_status_online 0xedb93949 +#define CODE_user_status_offline 0x8c703f +#define CODE_sign_in 0xbcd51581 +#define CODE_file_location 0x53d69076 +#define CODE_file_location_unavailable 0x7c596b46 +#define CODE_contacts_get_contacts 0x22c6aa08 +#define CODE_contacts_contacts 0x6f8b8cb2 +#define CODE_contact 0xf911c994 +#define CODE_user_empty 0x200250ba +#define CODE_user_contact 0xf2fb8319 +#define CODE_user_request 0x22e8ceb0 +#define CODE_user_foreign 0x5214c89d +#define CODE_user_deleted 0xb29ad7cc +#define CODE_gzip_packed 0x3072cfa1 + + +/* not really a limit, for struct encrypted_message only */ +// #define MAX_MESSAGE_INTS 16384 +#define MAX_MESSAGE_INTS 1048576 +#define MAX_PROTO_MESSAGE_INTS 1048576 + +#pragma pack(push,4) +struct encrypted_message { + // unencrypted header + long long auth_key_id; + char msg_key[16]; + // encrypted part, starts with encrypted header + long long server_salt; + long long session_id; + // long long auth_key_id2; // removed + // first message follows + long long msg_id; + int seq_no; + int msg_len; // divisible by 4 + int message[MAX_MESSAGE_INTS]; +}; + +struct worker_descr { + int addr; + int port; + int pid; + int start_time; + int id; +}; + +struct rpc_ready_packet { + int len; + int seq_num; + int type; + struct worker_descr worker; + int worker_ready_cnt; + int crc32; +}; + + +struct front_descr { + int addr; + int port; + int pid; + int start_time; + int id; +}; + +struct rpc_front_packet { + int len; + int seq_num; + int type; + struct front_descr front; + long long hash_mult; + int rem, mod; + int crc32; +}; + +struct middle_descr { + int addr; + int port; + int pid; + int start_time; + int id; +}; + +struct rpc_front_ack { + int len; + int seq_num; + int type; + struct middle_descr middle; + int crc32; +}; + +struct rpc_front_err { + int len; + int seq_num; + int type; + int errcode; + struct middle_descr middle; + long long hash_mult; + int rem, mod; + int crc32; +}; + +struct rpc_proxy_req { + int len; + int seq_num; + int type; + int flags; + long long ext_conn_id; + unsigned char remote_ipv6[16]; + int remote_port; + unsigned char our_ipv6[16]; + int our_port; + int data[]; +}; + +#define PROXY_HDR(__x) ((struct rpc_proxy_req *)((__x) - offsetof(struct rpc_proxy_req, data))) + +struct rpc_proxy_ans { + int len; + int seq_num; + int type; + int flags; // +16 = small error packet, +8 = flush immediately + long long ext_conn_id; + int data[]; +}; + +struct rpc_close_conn { + int len; + int seq_num; + int type; + long long ext_conn_id; + int crc32; +}; + +struct rpc_close_ext { + int len; + int seq_num; + int type; + long long ext_conn_id; + int crc32; +}; + +struct rpc_simple_ack { + int len; + int seq_num; + int type; + long long ext_conn_id; + int confirm_key; + int crc32; +}; + +#pragma pack(pop) + +BN_CTX *BN_ctx; + +void prng_seed (const char *password_filename, int password_length); +int serialize_bignum (BIGNUM *b, char *buffer, int maxlen); +long long compute_rsa_key_fingerprint (RSA *key); + +#define PACKET_BUFFER_SIZE (16384 * 100) // temp fix +int packet_buffer[PACKET_BUFFER_SIZE], *packet_ptr; + +static inline void out_ints (int *what, int len) { + assert (packet_ptr + len <= packet_buffer + PACKET_BUFFER_SIZE); + memcpy (packet_ptr, what, len * 4); + packet_ptr += len; +} + + +static inline void out_int (int x) { + assert (packet_ptr + 1 <= packet_buffer + PACKET_BUFFER_SIZE); + *packet_ptr++ = x; +} + + +static inline void out_long (long long x) { + assert (packet_ptr + 2 <= packet_buffer + PACKET_BUFFER_SIZE); + *(long long *)packet_ptr = x; + packet_ptr += 2; +} + +static inline void clear_packet (void) { + packet_ptr = packet_buffer; +} + +void out_cstring (const char *str, long len); +void out_cstring_careful (const char *str, long len); +void out_data (const char *data, long len); + +static inline void out_string (const char *str) { + out_cstring (str, strlen (str)); +} + +static inline void out_bignum (BIGNUM *n) { + int l = serialize_bignum (n, (char *)packet_ptr, (PACKET_BUFFER_SIZE - (packet_ptr - packet_buffer)) * 4); + assert (l > 0); + packet_ptr += l >> 2; +} + +extern int *in_ptr, *in_end; + +static inline int prefetch_strlen (void) { + if (in_ptr >= in_end) { + return -1; + } + unsigned l = *in_ptr; + if ((l & 0xff) < 0xfe) { + l &= 0xff; + return (in_end >= in_ptr + (l >> 2) + 1) ? (int)l : -1; + } else if ((l & 0xff) == 0xfe) { + l >>= 8; + return (l >= 254 && in_end >= in_ptr + ((l + 7) >> 2)) ? (int)l : -1; + } else { + return -1; + } +} + + +static inline char *fetch_str (int len) { + if (len < 254) { + char *str = (char *) in_ptr + 1; + in_ptr += 1 + (len >> 2); + return str; + } else { + char *str = (char *) in_ptr + 4; + in_ptr += (len + 7) >> 2; + return str; + } +} + +static inline char *fetch_str_dup (void) { + int l = prefetch_strlen (); + return strndup (fetch_str (l), l); +} + +static __inline__ unsigned long long rdtsc(void) { + unsigned hi, lo; + __asm__ __volatile__ ("rdtsc" : "=a"(lo), "=d"(hi)); + return ( (unsigned long long)lo)|( ((unsigned long long)hi)<<32 ); +} + +static inline long have_prefetch_ints (void) { + return in_end - in_ptr; +} + +int fetch_bignum (BIGNUM *x); + +static inline int fetch_int (void) { + return *(in_ptr ++); +} + +static inline int prefetch_int (void) { + return *(in_ptr); +} + +static inline long long fetch_long (void) { + long long r = *(long long *)in_ptr; + in_ptr += 2; + return r; +} + +int get_random_bytes (void *buf, int n); + +int pad_rsa_encrypt (char *from, int from_len, char *to, int size, BIGNUM *N, BIGNUM *E); +int pad_rsa_decrypt (char *from, int from_len, char *to, int size, BIGNUM *N, BIGNUM *D); + +extern long long rsa_encrypted_chunks, rsa_decrypted_chunks; + +extern unsigned char aes_key_raw[32], aes_iv[32]; +extern AES_KEY aes_key; + +void init_aes_unauth (const char server_nonce[16], const char hidden_client_nonce[32], int encrypt); +void init_aes_auth (char auth_key[192], char msg_key[16], int encrypt); +int pad_aes_encrypt (char *from, int from_len, char *to, int size); +int pad_aes_decrypt (char *from, int from_len, char *to, int size); + +static inline void hexdump_in (void) { + int *ptr = in_ptr; + while (ptr < in_end) { fprintf (stderr, " %08x", *(ptr ++)); } + fprintf (stderr, "\n"); +} +#endif diff --git a/net.c b/net.c new file mode 100644 index 0000000..69e5c4c --- /dev/null +++ b/net.c @@ -0,0 +1,436 @@ +#define _GNU_SOURCE +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "net.h" +#include "include.h" +#include "mtproto-client.h" +#include "mtproto-common.h" +#include "tree.h" + +DEFINE_TREE(int,int,int_cmp,0) + +int verbosity; +extern struct connection_methods auth_methods; + +struct connection_buffer *new_connection_buffer (int size) { + struct connection_buffer *b = malloc (sizeof (*b)); + memset (b, 0, sizeof (*b)); + b->start = malloc (size); + b->end = b->start + size; + b->rptr = b->wptr = b->start; + return b; +} + +void delete_connection_buffer (struct connection_buffer *b) { + free (b->start); + free (b); +} + +int write_out (struct connection *c, const void *data, int len) { + if (!len) { return 0; } + assert (len > 0); + int x = 0; + if (!c->out_head) { + struct connection_buffer *b = new_connection_buffer (1 << 20); + c->out_head = c->out_tail = b; + } + while (len) { + if (c->out_tail->end - c->out_tail->wptr >= len) { + memcpy (c->out_tail->wptr, data, len); + c->out_tail->wptr += len; + c->out_bytes += len; + return x + len; + } else { + int y = c->out_tail->end - c->out_tail->wptr; + assert (y < len); + memcpy (c->out_tail->wptr, data, y); + x += y; + len -= y; + data += y; + struct connection_buffer *b = new_connection_buffer (1 << 20); + c->out_tail->next = b; + b->next = 0; + c->out_tail = b; + c->out_bytes += y; + } + } + return x; +} + +int read_in (struct connection *c, void *data, int len) { + if (!len) { return 0; } + assert (len > 0); + if (len > c->in_bytes) { + len = c->in_bytes; + } + int x = 0; + while (len) { + int y = c->in_head->wptr - c->in_head->rptr; + if (y > len) { + memcpy (data, c->in_head->rptr, len); + c->in_head->rptr += len; + c->in_bytes -= len; + return x + len; + } else { + memcpy (data, c->in_head->rptr, y); + c->in_bytes -= y; + x += y; + data += y; + len -= y; + void *old = c->in_head; + c->in_head = c->in_head->next; + if (!c->in_head) { + c->in_tail = 0; + } + delete_connection_buffer (old); + } + } + return x; +} + +int read_in_lookup (struct connection *c, void *data, int len) { + if (!len) { return 0; } + assert (len > 0); + if (len > c->in_bytes) { + len = c->in_bytes; + } + int x = 0; + struct connection_buffer *b = c->in_head; + while (len) { + int y = b->wptr - b->rptr; + if (y > len) { + memcpy (data, b->rptr, len); + return x + len; + } else { + memcpy (data, b->rptr, y); + x += y; + b = b->next; + } + } + return x; +} + +void flush_out (struct connection *c UU) { +} + +#define MAX_CONNECTIONS 100 +struct connection *Connections[MAX_CONNECTIONS]; +int max_connection_fd; + +struct connection *create_connection (const char *host, int port, struct session *session, struct connection_methods *methods) { + struct connection *c = malloc (sizeof (*c)); + memset (c, 0, sizeof (*c)); + struct hostent *h; + if (!(h = gethostbyname (host)) || h->h_addrtype != AF_INET || h->h_length != 4 || !h->h_addr_list || !h->h_addr) { + assert (0); + } + int fd; + assert ((fd = socket (AF_INET, SOCK_STREAM, 0)) != -1); + assert (fd >= 0 && fd < MAX_CONNECTIONS); + if (fd > max_connection_fd) { + max_connection_fd = fd; + } + int flags = -1; + setsockopt (fd, SOL_SOCKET, SO_REUSEADDR, &flags, sizeof (flags)); + setsockopt (fd, SOL_SOCKET, SO_KEEPALIVE, &flags, sizeof (flags)); + setsockopt (fd, IPPROTO_TCP, TCP_NODELAY, &flags, sizeof (flags)); + + struct sockaddr_in addr; + addr.sin_family = AF_INET; + addr.sin_port = htons (port); + addr.sin_addr.s_addr = inet_addr (host); + + + fcntl (fd, F_SETFL, O_NONBLOCK); + + if (connect (fd, (struct sockaddr *) &addr, sizeof (addr)) == -1) { + if (errno != EINPROGRESS) { + fprintf (stderr, "Can not connect to %s:%d %m\n", host, port); + close (fd); + free (c); + return 0; + } + } + + struct pollfd s; + s.fd = fd; + s.events = POLLOUT | POLLERR | POLLRDHUP | POLLHUP; + + if (poll (&s, 1, 10000) <= 0 || !(s.revents & POLLOUT)) { + perror ("poll"); + close (fd); + free (c); + return 0; + } + + c->session = session; + c->fd = fd; + c->ip = htonl (*(int *)h->h_addr); + c->flags = 0; + c->state = conn_ready; + c->methods = methods; + assert (!Connections[fd]); + Connections[fd] = c; + if (verbosity) { + fprintf (stderr, "connect to %s:%d successful\n", host, port); + } + if (c->methods->ready) { + c->methods->ready (c); + } + return c; +} + +void fail_connection (struct connection *c) { + struct connection_buffer *b = c->out_head; + while (b) { + struct connection_buffer *d = b; + b = b->next; + delete_connection_buffer (d); + } + b = c->in_head; + while (b) { + struct connection_buffer *d = b; + b = b->next; + delete_connection_buffer (d); + } + c->out_head = c->out_tail = c->in_head = c->in_tail = 0; + c->state = conn_failed; + c->out_bytes = c->in_bytes = 0; +} + +void try_write (struct connection *c) { + if (verbosity) { + fprintf (stderr, "try write: fd = %d\n", c->fd); + } + int x = 0; + while (c->out_head) { + int r = write (c->fd, c->out_head->rptr, c->out_head->wptr - c->out_head->rptr); + if (r >= 0) { + x += r; + c->out_head->rptr += r; + if (c->out_head->rptr != c->out_head->wptr) { + break; + } + struct connection_buffer *b = c->out_head; + c->out_head = b->next; + if (!c->out_head) { + c->out_tail = 0; + } + delete_connection_buffer (b); + } else { + if (errno != EAGAIN && errno != EWOULDBLOCK) { + fail_connection (c); + return; + } else { + break; + } + } + } + if (verbosity) { + fprintf (stderr, "Sent %d bytes to %d\n", x, c->fd); + } + c->out_bytes -= x; +} + +void hexdump (struct connection_buffer *b) { + int pos = 0; + int rem = 8; + while (b) { + unsigned char *c = b->rptr; + while (c != b->wptr) { + if (rem == 8) { + if (pos) { printf ("\n"); } + printf ("%04d", pos); + } + printf (" %02x", (int)*c); + rem --; + pos ++; + if (!rem) { + rem = 8; + } + c ++; + } + b = b->next; + } + printf ("\n"); + +} + +void try_rpc_read (struct connection *c) { + assert (c->in_head); + if (verbosity >= 4) { + hexdump (c->in_head); + } + + while (1) { + if (c->in_bytes < 1) { return; } + unsigned len = 0; + unsigned t = 0; + assert (read_in_lookup (c, &len, 1) == 1); + if (len >= 1 && len <= 0x7e) { + if (c->in_bytes < (int)(4 * len)) { return; } + } else { + if (c->in_bytes < 4) { return; } + assert (read_in_lookup (c, &len, 4) == 4); + len = (len >> 8); + if (c->in_bytes < (int)(4 * len)) { return; } + len = 0x7f; + } + + if (len >= 1 && len <= 0x7e) { + assert (read_in (c, &t, 1) == 1); + assert (t == len); + assert (len >= 1); + } else { + assert (len == 0x7f); + assert (read_in (c, &len, 4) == 4); + len = (len >> 8); + assert (len >= 1); + } + len *= 4; + int op; + assert (read_in_lookup (c, &op, 4) == 4); + c->methods->execute (c, op, len); + } +} + +void try_read (struct connection *c) { + if (verbosity) { + fprintf (stderr, "try read: fd = %d\n", c->fd); + } + if (!c->in_tail) { + c->in_head = c->in_tail = new_connection_buffer (1 << 20); + } + int x = 0; + while (1) { + int r = read (c->fd, c->in_tail->wptr, c->in_tail->end - c->in_tail->wptr); + if (r >= 0) { + c->in_tail->wptr += r; + x += r; + if (c->in_tail->wptr != c->in_tail->end) { + break; + } + struct connection_buffer *b = new_connection_buffer (1 << 20); + c->in_tail->next = b; + c->in_tail = b; + } else { + if (errno != EAGAIN && errno != EWOULDBLOCK) { + fail_connection (c); + return; + } else { + break; + } + } + } + if (verbosity) { + fprintf (stderr, "Received %d bytes from %d\n", x, c->fd); + } + c->in_bytes += x; + if (x) { + try_rpc_read (c); + } +} + +int connections_make_poll_array (struct pollfd *fds, int max) { + int _max = max; + int i; + for (i = 0; i <= max_connection_fd; i++) if (Connections[i] && Connections[i]->state != conn_failed) { + assert (max > 0); + struct connection *c = Connections[i]; + fds[0].fd = c->fd; + fds[0].events = POLLERR | POLLHUP | POLLRDHUP | POLLIN; + if (c->out_bytes || c->state == conn_connecting) { + fds[0].events |= POLLOUT; + } + fds ++; + max --; + } + if (verbosity >= 3) { + fprintf (stderr, "%d connections in poll\n", _max - max); + } + return _max - max; +} + +void connections_poll_result (struct pollfd *fds, int max) { + if (verbosity >= 2) { + fprintf (stderr, "connections_poll_result: max = %d\n", max); + } + int i; + for (i = 0; i < max; i++) { + struct connection *c = Connections[fds[i].fd]; + if (fds[i].revents & POLLIN) { + try_read (c); + } + if (fds[i].revents & (POLLHUP | POLLERR | POLLRDHUP)) { + if (verbosity) { + fprintf (stderr, "fail connection\n"); + } + fail_connection (c); + } else if (fds[i].revents & POLLOUT) { + if (c->state == conn_connecting) { + c->state = conn_ready; + } + if (c->out_bytes) { + try_write (c); + } + } + } +} + +int send_all_acks (struct session *S) { + clear_packet (); + out_int (tree_count_int (S->ack_tree)); + while (S->ack_tree) { + int x = tree_get_min_int (S->ack_tree); + out_int (x); + S->ack_tree = tree_delete_int (S->ack_tree, x); + } + encrypt_send_message (S->c, packet_buffer, packet_ptr - packet_buffer, 0); + return 0; +} + +void insert_seqno (struct session *S, int seqno) { + if (!S->ack_tree) { + S->ev.alarm = (void *)send_all_acks; + S->ev.self = (void *)S; + S->ev.timeout = get_double_time () + ACK_TIMEOUT; + insert_event_timer (&S->ev); + } + if (!tree_lookup_int (S->ack_tree, seqno)) { + S->ack_tree = tree_insert_int (S->ack_tree, seqno, lrand48 ()); + } +} + +extern struct dc *DC_list[]; + +struct dc *alloc_dc (int id, char *ip, int port) { + assert (!DC_list[id]); + struct dc *DC = malloc (sizeof (*DC)); + memset (DC, 0, sizeof (*DC)); + DC->id = id; + DC->ip = ip; + DC->port = port; + DC_list[id] = DC; + return DC; +} + +void dc_create_session (struct dc *DC) { + struct session *S = malloc (sizeof (*S)); + memset (S, 0, sizeof (*S)); + assert (RAND_pseudo_bytes ((unsigned char *) &S->session_id, 8) >= 0); + S->dc = DC; + S->c = create_connection (DC->ip, DC->port, S, &auth_methods); + assert (!DC->sessions[0]); + DC->sessions[0] = S; +} diff --git a/net.h b/net.h new file mode 100644 index 0000000..e934920 --- /dev/null +++ b/net.h @@ -0,0 +1,121 @@ +#ifndef __NET_H__ +#define __NET_H__ + +#include +struct dc; +#include "queries.h" +#define TG_SERVER "173.240.5.1" +//#define TG_SERVER "95.142.192.66" +#define TG_APP_HASH "3bc14c6455ef1595ec86a125762c3aad" +#define TG_APP_ID 51 + +#define ACK_TIMEOUT 60 +#define MAX_DC_ID 10 + +enum dc_state{ + st_init, + st_reqpq_sent, + st_reqdh_sent, + st_client_dh_sent, + st_authorized, + st_error +} ; + +struct connection; +struct connection_methods { + int (*ready) (struct connection *c); + int (*close) (struct connection *c); + int (*execute) (struct connection *c, int op, int len); +}; + + +#define MAX_DC_SESSIONS 3 + +struct session { + struct dc *dc; + long long session_id; + int seq_no; + struct connection *c; + struct tree_int *ack_tree; + struct event_timer ev; +}; + +struct dc { + int id; + int port; + int flags; + char *ip; + char *user; + struct session *sessions[MAX_DC_SESSIONS]; + char auth_key[256]; + long long auth_key_id; + long long server_salt; + + int server_time_delta; + double server_time_udelta; +}; + +#define DC_SERIALIZED_MAGIC 0x64582faa +struct dc_serialized { + int magic; + int port; + char ip[64]; + char user[64]; + char auth_key[256]; + long long auth_key_id, server_salt; + int authorized; +}; + +struct connection_buffer { + void *start; + void *end; + void *rptr; + void *wptr; + struct connection_buffer *next; +}; + +enum conn_state { + conn_none, + conn_connecting, + conn_ready, + conn_failed, + conn_stopped +}; + +struct connection { + int fd; + int ip; + int port; + int flags; + enum conn_state state; + int ipv6[4]; + struct connection_buffer *in_head; + struct connection_buffer *in_tail; + struct connection_buffer *out_head; + struct connection_buffer *out_tail; + int in_bytes; + int out_bytes; + int packet_num; + int out_packet_num; + struct connection_methods *methods; + struct session *session; + void *extra; +}; + +extern struct connection *Connections[]; + +int write_out (struct connection *c, const void *data, int len); +void flush_out (struct connection *c); +int read_in (struct connection *c, void *data, int len); + +void create_all_outbound_connections (void); + +struct connection *create_connection (const char *host, int port, struct session *session, struct connection_methods *methods); +int connections_make_poll_array (struct pollfd *fds, int max); +void connections_poll_result (struct pollfd *fds, int max); +void dc_create_session (struct dc *DC); +void insert_seqno (struct session *S, int seqno); +struct dc *alloc_dc (int id, char *ip, int port); + +#define GET_DC(c) (c->session->dc) +#endif diff --git a/queries.c b/queries.c new file mode 100644 index 0000000..13b21ca --- /dev/null +++ b/queries.c @@ -0,0 +1,404 @@ +#include +#include +#include +#include + +#include "include.h" +#include "mtproto-client.h" +#include "queries.h" +#include "tree.h" +#include "mtproto-common.h" +#include "telegram.h" +#include "loop.h" +#include "structures.h" +#include "interface.h" + +int verbosity; + +#define QUERY_TIMEOUT 0.3 + +#define memcmp8(a,b) memcmp ((a), (b), 8) +DEFINE_TREE (query, struct query *, memcmp8, 0) ; +struct tree_query *queries_tree; + +double get_double_time (void) { + struct timespec tv; + clock_gettime (CLOCK_REALTIME, &tv); + return tv.tv_sec + 1e-9 * tv.tv_nsec; +} + +struct query *query_get (long long id) { + return tree_lookup_query (queries_tree, (void *)&id); +} + +int alarm_query (struct query *q) { + assert (q); + return 0; +} + +struct query *send_query (struct dc *DC, int ints, void *data, struct query_methods *methods) { + assert (DC); + assert (DC->auth_key_id); + if (!DC->sessions[0]) { + dc_create_session (DC); + } + if (verbosity) { + fprintf (stderr, "Sending query of size %d to DC (%s:%d)\n", 4 * ints, DC->ip, DC->port); + } + struct query *q = malloc (sizeof (*q)); + q->data_len = ints; + q->data = malloc (4 * ints); + memcpy (q->data, data, 4 * ints); + q->msg_id = encrypt_send_message (DC->sessions[0]->c, data, ints, 1); + if (verbosity) { + fprintf (stderr, "Msg_id is %lld\n", q->msg_id); + } + q->methods = methods; + if (queries_tree) { + fprintf (stderr, "%lld %lld\n", q->msg_id, queries_tree->x->msg_id); + } + queries_tree = tree_insert_query (queries_tree, q, lrand48 ()); + + q->ev.alarm = (void *)alarm_query; + q->ev.timeout = get_double_time () + QUERY_TIMEOUT; + q->ev.self = (void *)q; + insert_event_timer (&q->ev); + return q; +} + +void query_ack (long long id) { + struct query *q = query_get (id); + if (q) { q->flags |= QUERY_ACK_RECEIVED; } +} + +void query_error (long long id) { + assert (fetch_int () == CODE_rpc_error); + int error_code = fetch_int (); + int error_len = prefetch_strlen (); + char *error = fetch_str (error_len); + if (verbosity) { + fprintf (stderr, "error for query #%lld: #%d :%.*s\n", id, error_code, error_len, error); + } + struct query *q = query_get (id); + if (!q) { + if (verbosity) { + fprintf (stderr, "No such query\n"); + } + } else { + remove_event_timer (&q->ev); + queries_tree = tree_delete_query (queries_tree, q); + if (q->methods && q->methods->on_error) { + q->methods->on_error (q, error_code, error_len, error); + } + free (q->data); + free (q); + } +} + +#define MAX_PACKED_SIZE (1 << 20) +static int packed_buffer[MAX_PACKED_SIZE / 4]; + +void query_result (long long id UU) { + if (verbosity) { + fprintf (stderr, "result for query #%lld\n", id); + } + if (verbosity >= 4) { + fprintf (stderr, "result: "); + hexdump_in (); + } + int op = prefetch_int (); + int *end = 0; + int *eend = 0; + if (op == CODE_gzip_packed) { + fetch_int (); + int l = prefetch_strlen (); + char *s = fetch_str (l); + size_t dl = MAX_PACKED_SIZE; + + z_stream strm = {0}; + assert (inflateInit2 (&strm, 16 + MAX_WBITS) == Z_OK); + strm.avail_in = l; + strm.next_in = (void *)s; + strm.avail_out = MAX_PACKED_SIZE; + strm.next_out = (void *)packed_buffer; + + int err = inflate (&strm, Z_FINISH); + if (verbosity) { + fprintf (stderr, "inflate error = %d\n", err); + fprintf (stderr, "inflated %d bytes\n", (int)strm.total_out); + } + end = in_ptr; + eend = in_end; + assert (dl % 4 == 0); + in_ptr = packed_buffer; + in_end = in_ptr + strm.total_out / 4; + if (verbosity >= 4) { + fprintf (stderr, "Unzipped data: "); + hexdump_in (); + } + } + struct query *q = query_get (id); + if (!q) { + if (verbosity) { + fprintf (stderr, "No such query\n"); + } + } else { + remove_event_timer (&q->ev); + queries_tree = tree_delete_query (queries_tree, q); + if (q->methods && q->methods->on_answer) { + q->methods->on_answer (q); + } + free (q->data); + free (q); + } + if (end) { + in_ptr = end; + in_end = eend; + } +} + +#define event_timer_cmp(a,b) ((a)->timeout > (b)->timeout ? 1 : ((a)->timeout < (b)->timeout ? -1 : (memcmp (a, b, sizeof (struct event_timer))))) +DEFINE_TREE (timer, struct event_timer *, event_timer_cmp, 0) +struct tree_timer *timer_tree; + +void insert_event_timer (struct event_timer *ev) { + return; + fprintf (stderr, "INSERT: %lf %p %p\n", ev->timeout, ev->self, ev->alarm); + tree_check_timer (timer_tree); + timer_tree = tree_insert_timer (timer_tree, ev, lrand48 ()); + tree_check_timer (timer_tree); +} + +void remove_event_timer (struct event_timer *ev) { + return; + fprintf (stderr, "REMOVE: %lf %p %p\n", ev->timeout, ev->self, ev->alarm); + tree_check_timer (timer_tree); + timer_tree = tree_delete_timer (timer_tree, ev); + tree_check_timer (timer_tree); +} + +double next_timer_in (void) { + if (!timer_tree) { return 1e100; } + return tree_get_min_timer (timer_tree)->timeout; +} + +void work_timers (void) { + double t = get_double_time (); + while (timer_tree) { + struct event_timer *ev = tree_get_min_timer (timer_tree); + assert (ev); + if (ev->timeout > t) { break; } + remove_event_timer (ev); + ev->alarm (ev->self); + } +} + +int max_chat_size; +int want_dc_num; +extern struct dc *DC_list[]; +extern struct dc *DC_working; + +int help_get_config_on_answer (struct query *q UU) { + assert (fetch_int () == CODE_config); + fetch_int (); + + unsigned test_mode = fetch_int (); + assert (test_mode == CODE_bool_true || test_mode == CODE_bool_false); + assert (test_mode == CODE_bool_false); + int this_dc = fetch_int (); + if (verbosity) { + fprintf (stderr, "this_dc = %d\n", this_dc); + } + assert (fetch_int () == CODE_vector); + int n = fetch_int (); + assert (n <= 10); + int i; + for (i = 0; i < n; i++) { + assert (fetch_int () == CODE_dc_option); + int id = fetch_int (); + int l1 = prefetch_strlen (); + char *name = fetch_str (l1); + int l2 = prefetch_strlen (); + char *ip = fetch_str (l2); + int port = fetch_int (); + if (verbosity) { + fprintf (stderr, "id = %d, name = %.*s ip = %.*s port = %d\n", id, l1, name, l2, ip, port); + } + if (!DC_list[id]) { + alloc_dc (id, strndup (ip, l2), port); + } + } + max_chat_size = fetch_int (); + if (verbosity >= 2) { + fprintf (stderr, "chat_size = %d\n", max_chat_size); + } + return 0; +} + +struct query_methods help_get_config_methods = { + .on_answer = help_get_config_on_answer +}; + +char *phone_code_hash; +int send_code_on_answer (struct query *q UU) { + assert (fetch_int () == CODE_auth_sent_code); + assert (fetch_int () == (int)CODE_bool_true); + int l = prefetch_strlen (); + char *s = fetch_str (l); + if (phone_code_hash) { + free (phone_code_hash); + } + phone_code_hash = strndup (s, l); + want_dc_num = -1; + return 0; +} + +int send_code_on_error (struct query *q UU, int error_code, int l, char *error) { + int s = strlen ("PHONE_MIGRATE_"); + if (l >= s && !memcmp (error, "PHONE_MIGRATE_", s)) { + int i = error[s] - '0'; + want_dc_num = i; + } else { + fprintf (stderr, "error_code = %d, error = %.*s\n", error_code, l, error); + assert (0); + } + return 0; +} + +struct query_methods send_code_methods = { + .on_answer = send_code_on_answer, + .on_error = send_code_on_error +}; + +int code_is_sent (void) { + return want_dc_num; +} + +int config_got (void) { + return DC_list[want_dc_num] != 0; +} + +char *suser; +extern int dc_working_num; +void do_send_code (const char *user) { + suser = strdup (user); + want_dc_num = 0; + clear_packet (); + out_int (CODE_auth_send_code); + out_string (user); + out_int (0); + out_int (TG_APP_ID); + out_string (TG_APP_HASH); + + send_query (DC_working, packet_ptr - packet_buffer, packet_buffer, &send_code_methods); + net_loop (0, code_is_sent); + if (want_dc_num == -1) { return; } + + if (DC_list[want_dc_num]) { + DC_working = DC_list[want_dc_num]; + if (!DC_working->auth_key_id) { + dc_authorize (DC_working); + } + if (!DC_working->sessions[0]) { + dc_create_session (DC_working); + } + dc_working_num = want_dc_num; + } else { + clear_packet (); + out_int (CODE_help_get_config); + send_query (DC_working, packet_ptr - packet_buffer, packet_buffer, &help_get_config_methods); + net_loop (0, config_got); + DC_working = DC_list[want_dc_num]; + if (!DC_working->auth_key_id) { + dc_authorize (DC_working); + } + if (!DC_working->sessions[0]) { + dc_create_session (DC_working); + } + dc_working_num = want_dc_num; + } + want_dc_num = 0; + clear_packet (); + out_int (CODE_auth_send_code); + out_string (user); + out_int (0); + out_int (TG_APP_ID); + out_string (TG_APP_HASH); + + send_query (DC_working, packet_ptr - packet_buffer, packet_buffer, &send_code_methods); + net_loop (0, code_is_sent); + assert (want_dc_num == -1); +} + +int sign_in_ok; +int sign_in_is_ok (void) { + return sign_in_ok; +} + +struct user User; + +int sign_in_on_answer (struct query *q UU) { + assert (fetch_int () == (int)CODE_auth_authorization); + int expires = fetch_int (); + fetch_user (&User); + sign_in_ok = 1; + if (verbosity) { + fprintf (stderr, "authorized successfully: name = '%s %s', phone = '%s', expires = %d\n", User.first_name, User.last_name, User.phone, (int)(expires - get_double_time ())); + } + return 0; +} + +int sign_in_on_error (struct query *q UU, int error_code, int l, char *error) { + fprintf (stderr, "error_code = %d, error = %.*s\n", error_code, l, error); + sign_in_ok = -1; + return 0; +} + +struct query_methods sign_in_methods = { + .on_answer = sign_in_on_answer, + .on_error = sign_in_on_error +}; + +int do_send_code_result (const char *code) { + clear_packet (); + out_int (CODE_sign_in); + out_string (suser); + out_string (phone_code_hash); + out_string (code); + send_query (DC_working, packet_ptr - packet_buffer, packet_buffer, &sign_in_methods); + sign_in_ok = 0; + net_loop (0, sign_in_is_ok); + return sign_in_ok; +} + +int get_contacts_on_answer (struct query *q UU) { + assert (fetch_int () == (int)CODE_contacts_contacts); + assert (fetch_int () == CODE_vector); + int n = fetch_int (); + int i; + for (i = 0; i < n; i++) { + assert (fetch_int () == (int)CODE_contact); + fetch_int (); // id + fetch_int (); // mutual + } + assert (fetch_int () == CODE_vector); + n = fetch_int (); + for (i = 0; i < n; i++) { + struct user User; + fetch_user (&User); + rprintf ("User: id = %d, first_name = %s, last_name = %s\n", User.id, User.first_name, User.last_name); + } + return 0; +} + +struct query_methods get_contacts_methods = { + .on_answer = get_contacts_on_answer, +}; + + +void do_update_contact_list (void) { + clear_packet (); + out_int (CODE_contacts_get_contacts); + out_string (""); + send_query (DC_working, packet_ptr - packet_buffer, packet_buffer, &get_contacts_methods); +} diff --git a/queries.h b/queries.h new file mode 100644 index 0000000..3f812e2 --- /dev/null +++ b/queries.h @@ -0,0 +1,47 @@ +#include "net.h" +#ifndef __QUERIES_H__ +#define __QUERIES_H__ + +#define QUERY_ACK_RECEIVED 1 + +struct query; +struct query_methods { + int (*on_answer)(struct query *q); + int (*on_error)(struct query *q, int error_code, int len, char *error); + int (*on_timeout)(struct query *q); +}; + +struct event_timer { + double timeout; + int (*alarm)(void *self); + void *self; +}; + +struct query { + long long msg_id; + int data_len; + int flags; + void *data; + struct query_methods *methods; + struct event_timer ev; +}; + + +struct query *send_query (struct dc *DC, int len, void *data, struct query_methods *methods); +void query_ack (long long id); +void query_error (long long id); +void query_result (long long id); + +void insert_event_timer (struct event_timer *ev); +void remove_event_timer (struct event_timer *ev); +double next_timer_in (void); +void work_timers (void); + +extern struct query_methods help_get_config_methods; + +void do_send_code (const char *user); +int do_send_code_result (const char *code); +double get_double_time (void); + +void do_update_contact_list (void); +#endif diff --git a/structures.c b/structures.c new file mode 100644 index 0000000..aad5eb2 --- /dev/null +++ b/structures.c @@ -0,0 +1,81 @@ +#include +#include "structures.h" +#include "mtproto-common.h" + +void fetch_file_location (struct file_location *loc) { + int x = fetch_int (); + if (x == CODE_file_location_unavailable) { + loc->dc = -1; + loc->volume = fetch_long (); + loc->local_id = fetch_int (); + loc->secret = fetch_long (); + } else { + assert (x == CODE_file_location); + loc->dc = fetch_int ();; + loc->volume = fetch_long (); + loc->local_id = fetch_int (); + loc->secret = fetch_long (); + } +} + +void fetch_user_status (struct user_status *S) { + int x = fetch_int (); + switch (x) { + case CODE_user_status_empty: + S->online = 0; + break; + case CODE_user_status_online: + S->online = 1; + S->when = fetch_int (); + break; + case CODE_user_status_offline: + S->online = -1; + S->when = fetch_int (); + break; + default: + assert (0); + } +} + +void fetch_user (struct user *U) { + memset (U, 0, sizeof (*U)); + unsigned x = fetch_int (); + assert (x == CODE_user_empty || x == CODE_user_self || x == CODE_user_contact || x == CODE_user_request || x == CODE_user_foreign || x == CODE_user_deleted); + U->id = fetch_int (); + if (x == CODE_user_empty) { + U->flags = 1; + return; + } + U->first_name = fetch_str_dup (); + U->last_name = fetch_str_dup (); + if (x == CODE_user_deleted) { + U->flags = 2; + return; + } + if (x == CODE_user_self) { + U->flags = 4; + } else { + U->access_hash = fetch_long (); + } + if (x == CODE_user_foreign) { + U->flags |= 8; + } else { + U->phone = fetch_str_dup (); + } + unsigned y = fetch_int (); + if (y == CODE_user_profile_photo_empty) { + U->photo_small.dc = -2; + U->photo_big.dc = -2; + } else { + assert (y == CODE_user_profile_photo); + fetch_file_location (&U->photo_small); + fetch_file_location (&U->photo_big); + } + fetch_user_status (&U->status); + if (x == CODE_user_self) { + assert (fetch_int () == (int)CODE_bool_false); + } + if (x == CODE_user_contact) { + U->flags |= 16; + } +} diff --git a/structures.h b/structures.h new file mode 100644 index 0000000..bdb34bc --- /dev/null +++ b/structures.h @@ -0,0 +1,31 @@ +#ifndef __STRUCTURES_H__ +#define __STRUCTURES_H__ + + +struct file_location { + int dc; + long long volume; + int local_id; + long long secret; +}; + +struct user_status { + int online; + int when; +}; + +struct user { + int id; + int flags; + char *first_name; + char *last_name; + char *phone; + long long access_hash; + struct file_location photo_big; + struct file_location photo_small; + struct user_status status; +}; + +void fetch_file_location (struct file_location *loc); +void fetch_user (struct user *U); +#endif diff --git a/telegram.h b/telegram.h index e69de29..3958159 100644 --- a/telegram.h +++ b/telegram.h @@ -0,0 +1 @@ +#define MAX_DC_NUM 9 diff --git a/tree.h b/tree.h new file mode 100644 index 0000000..3145139 --- /dev/null +++ b/tree.h @@ -0,0 +1,116 @@ +#ifndef __TREE_H__ +#define __TREE_H__ +#include + +#include +#include +#include + +#define DEFINE_TREE(X_NAME, X_TYPE, X_CMP, X_UNSET) \ +struct tree_ ## X_NAME { \ + struct tree_ ## X_NAME *left, *right;\ + X_TYPE x;\ + int y;\ +};\ +\ +struct tree_ ## X_NAME *new_tree_node_ ## X_NAME (X_TYPE x, int y) {\ + struct tree_ ## X_NAME *T = malloc (sizeof (*T));\ + T->x = x;\ + T->y = y;\ + T->left = T->right = 0;\ + return T;\ +}\ +\ +void delete_tree_node_ ## X_NAME (struct tree_ ## X_NAME *T) {\ + free (T);\ +}\ +\ +void tree_split_ ## X_NAME (struct tree_ ## X_NAME *T, X_TYPE x, struct tree_ ## X_NAME **L, struct tree_ ## X_NAME **R) {\ + if (!T) {\ + *L = *R = 0;\ + } else {\ + int c = X_CMP (x, T->x);\ + if (c < 0) {\ + tree_split_ ## X_NAME (T->left, x, L, &T->left);\ + *R = T;\ + } else {\ + tree_split_ ## X_NAME (T->right, x, &T->right, R);\ + *L = T;\ + }\ + }\ +}\ +\ +struct tree_ ## X_NAME *tree_insert_ ## X_NAME (struct tree_ ## X_NAME *T, X_TYPE x, int y) {\ + if (!T) {\ + return new_tree_node_ ## X_NAME (x, y);\ + } else {\ + if (y > T->y) {\ + struct tree_ ## X_NAME *N = new_tree_node_ ## X_NAME (x, y);\ + tree_split_ ## X_NAME (T, x, &N->left, &N->right);\ + return N;\ + } else {\ + int c = X_CMP (x, T->x);\ + assert (c);\ + return tree_insert_ ## X_NAME (c < 0 ? T->left : T->right, x, y);\ + }\ + }\ +}\ +\ +struct tree_ ## X_NAME *tree_merge_ ## X_NAME (struct tree_ ## X_NAME *L, struct tree_ ## X_NAME *R) {\ + if (!L || !R) {\ + return L ? L : R;\ + } else {\ + if (L->y > R->y) {\ + L->right = tree_merge_ ## X_NAME (L->right, R);\ + return L;\ + } else {\ + R->left = tree_merge_ ## X_NAME (L, R->left);\ + return R;\ + }\ + }\ +}\ +\ +struct tree_ ## X_NAME *tree_delete_ ## X_NAME (struct tree_ ## X_NAME *T, X_TYPE x) {\ + assert (T);\ + int c = X_CMP (x, T->x);\ + if (!c) {\ + struct tree_ ## X_NAME *N = tree_merge_ ## X_NAME (T->left, T->right);\ + delete_tree_node_ ## X_NAME (T);\ + return N;\ + } else {\ + return tree_delete_ ## X_NAME (c < 0 ? T->left : T->right, x);\ + }\ +}\ +\ +X_TYPE tree_get_min_ ## X_NAME (struct tree_ ## X_NAME *T) {\ + if (!T) { return X_UNSET; } \ + while (T->left) { T = T->left; }\ + return T->x; \ +} \ +\ +X_TYPE tree_lookup_ ## X_NAME (struct tree_ ## X_NAME *T, X_TYPE x) {\ + int c;\ + while (T && (c = X_CMP (x, T->x))) {\ + T = (c < 0 ? T->left : T->right);\ + }\ + return T ? T->x : X_UNSET;\ +}\ +\ +int tree_count_ ## X_NAME (struct tree_ ## X_NAME *T) { \ + if (!T) { return 0; }\ + return 1 + tree_count_ ## X_NAME (T->left) + tree_count_ ## X_NAME (T->right); \ +}\ +void tree_check_ ## X_NAME (struct tree_ ## X_NAME *T) { \ + if (!T) { return; }\ + if (T->left) { \ + assert (T->left->y <= T->y);\ + assert (X_CMP (T->left->x, T->x) < 0); \ + }\ + if (T->right) { \ + assert (T->right->y <= T->y);\ + assert (X_CMP (T->right->x, T->x) > 0); \ + }\ +}\ + +#define int_cmp(a,b) ((a) - (b)) +#endif