/**********
 * Copyright (c) 2003-2005 Greg Parker.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY GREG PARKER ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 **********/
/*
 * Copyright (c) 2001 Markus Friedl.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include "includes.h"
#include "transport.h"

#include "session.h"
#include "connection.h"
#include "packetizer.h"
#include "ssh2.h"
#include "keyfile.h"
#include "data/prefs.h"
#include "openssh/buffer.h"
#include "openssh/bufaux.h"
#include "openssh/dh.h"
#include "openssh/kex.h"
#include "openssh/key.h"
#include "openssh/match.h"
#include "crypto/rand.h"

/* Prototypes for this file's private helpers and state handlers. Each is
 * tagged TRANSPORT_SEGMENT, presumably a macro that places the function in
 * a specific code segment (NOTE(review): confirm its definition). */
static int enc_init(Enc *enc, char *cipher) TRANSPORT_SEGMENT;
static int comp_init(Comp **comp, void **comp_ctx, char *compression) TRANSPORT_SEGMENT;
static int decomp_init(Comp **comp, void **comp_ctx, char *compression) TRANSPORT_SEGMENT;
static int same_preferred_algorithm(char *alg1, char *alg2) TRANSPORT_SEGMENT;
static void transport_set_supported_algorithms(ssh_session_t *ss) TRANSPORT_SEGMENT;
static void transport_free_supported_algorithms(ssh_session_t *ss) TRANSPORT_SEGMENT;
static uint16_t transport_choose_kex_key_length(ssh_session_t *ss) TRANSPORT_SEGMENT;
static void transport_send_kexinit(ssh_session_t *ss) TRANSPORT_SEGMENT;
static void transport_send_dh_init(ssh_session_t *ss, BIGNUM *publicKey) TRANSPORT_SEGMENT;
static void transport_send_newkeys(ssh_session_t *ss) TRANSPORT_SEGMENT;
static void transport_send_unimplemented(ssh_session_t *ss, uint32_t seq) TRANSPORT_SEGMENT;
static void transport_send_disconnect(ssh_session_t *ss, uint32_t code, char *str) TRANSPORT_SEGMENT;
static int transport_handle_kexinit_kex(ssh_session_t *ss, Buffer *packet, int *guessedWrong) TRANSPORT_SEGMENT;
static int transport_handle_kexinit_server_host_key(ssh_session_t *ss, Buffer *packet, int *guessedWrong) TRANSPORT_SEGMENT;
static int transport_handle_kexinit_cipher(ssh_session_t *ss, Buffer *packet) TRANSPORT_SEGMENT;
static int transport_handle_kexinit_mac(ssh_session_t *ss, Buffer *packet) TRANSPORT_SEGMENT;
static int transport_handle_kexinit_compression(ssh_session_t *ss, Buffer *packet) TRANSPORT_SEGMENT;
static int transport_handle_kexinit_languages(ssh_session_t *ss, Buffer *packet) TRANSPORT_SEGMENT;
static transport_state_t transport_handle_ignore(ssh_session_t *ss, Buffer *packet) TRANSPORT_SEGMENT;
static transport_state_t transport_handle_debug(ssh_session_t *ss, Buffer *packet) TRANSPORT_SEGMENT;
static transport_state_t transport_handle_disconnect(ssh_session_t *ss, Buffer *packet) TRANSPORT_SEGMENT;
static transport_state_t transport_handle_unimplemented(ssh_session_t *ss, Buffer *packet) TRANSPORT_SEGMENT;
static transport_state_t transport_handle_kexinit(ssh_session_t *ss, Buffer *packet) TRANSPORT_SEGMENT;
static transport_state_t transport_handle_dh_reply(ssh_session_t *ss, Buffer *packet) TRANSPORT_SEGMENT;
static transport_state_t transport_handle_newkeys(ssh_session_t *ss, Buffer *packet) TRANSPORT_SEGMENT;
static transport_state_t transport_handle_re_kexinit(ssh_session_t *ss, Buffer *packet) TRANSPORT_SEGMENT;
static transport_state_t transport_state_waiting_for_kexinit(ssh_session_t *ss, uint8_t msg, Buffer *packet) TRANSPORT_SEGMENT;
static transport_state_t transport_state_waiting_for_dh_reply(ssh_session_t *ss, uint8_t msg, Buffer *packet) TRANSPORT_SEGMENT;
static transport_state_t transport_state_waiting_for_newkeys(ssh_session_t *ss, uint8_t msg, Buffer *packet) TRANSPORT_SEGMENT;
static transport_state_t transport_state_running(ssh_session_t *ss, uint8_t msg, Buffer *packet) TRANSPORT_SEGMENT;
static void transport_cleanup(ssh_session_t *ss) TRANSPORT_SEGMENT;
static void transport_send_queued_payloads(struct ssh_session_t *ss) TRANSPORT_SEGMENT;

// Prepare `enc` for the cipher named `cipher` (e.g. "aes128-cbc").
// Looks the cipher up and records its name, key length and block size.
// The key/IV pointers are reset to NULL and the cipher left disabled; the
// key material is presumably filled in later by kex_derive_keys (TODO confirm).
// Returns 0 on success, -1 if the cipher name is unknown.
static int 
enc_init(Enc *enc, char *cipher)
{
    enc->cipher = cipher_by_name(cipher);
    if (!enc->cipher) return -1;
    enc->name = cipher_name(cipher_get_number(enc->cipher)); // fixme lame
    enc->enabled = 0;
    enc->iv = NULL;
    enc->key = NULL;
    enc->key_len = cipher_keylen(enc->cipher);
    enc->block_size = cipher_blocksize(enc->cipher);
    // include the cipher name: the argument was previously passed without
    // a matching conversion and so never appeared in the log
    transport_log("init cipher done: %s", cipher);
    return 0;
}

// Select the outgoing compressor for algorithm `name`.
//   "zlib" -> *comp = &ssh_zlib_compress and *comp_ctx = its init() result
//   "none" -> both outputs set to NULL (no compression)
// Returns 0 on success, -1 for any other name (outputs untouched).
static int 
comp_init(Comp **comp, void **comp_ctx, char *name)
{
    int is_zlib = (strcmp(name, "zlib") == 0);

    if (!is_zlib && strcmp(name, "none") != 0) {
        return -1;  // unrecognised algorithm
    }

    if (is_zlib) {
        *comp = &ssh_zlib_compress;
        *comp_ctx = (*comp)->init();
    } else {
        *comp = NULL;
        *comp_ctx = NULL;
    }
    return 0;
}

// Select the incoming decompressor for algorithm `name`.
//   "zlib" -> *comp = &ssh_zlib_decompress and *comp_ctx = its init() result
//   "none" -> both outputs set to NULL (no decompression)
// Returns 0 on success, -1 for any other name (outputs untouched).
static int 
decomp_init(Comp **comp, void **comp_ctx, char *name)
{
    int is_zlib = (strcmp(name, "zlib") == 0);

    if (!is_zlib && strcmp(name, "none") != 0) {
        return -1;  // unrecognised algorithm
    }

    if (is_zlib) {
        *comp = &ssh_zlib_decompress;
        *comp_ctx = (*comp)->init();
    } else {
        *comp = NULL;
        *comp_ctx = NULL;
    }
    return 0;
}

// Release the key material and compression context held by `keys`.
// Also called by the packetizer.
// Each freed pointer is reset to NULL so that calling this twice on the
// same Keys (or on a struct later reused) cannot free the memory again.
void keys_free(Keys *keys)
{
    if (keys->enc.iv) {
        xfree(keys->enc.iv);
        keys->enc.iv = NULL;
    }
    if (keys->enc.key) {
        xfree(keys->enc.key);
        keys->enc.key = NULL;
    }
    if (keys->mac.key) {
        xfree(keys->mac.key);
        keys->mac.key = NULL;
    }
    if (keys->comp  &&  keys->comp_ctx) {
        keys->comp->cleanup(keys->comp_ctx);
        keys->comp_ctx = NULL;
    }
}

// Return 1 iff the first (most preferred) entry of the comma-separated
// name-lists alg1 and alg2 is the same name, else 0.
// Used to decide whether the peer's first_kex_packet_follows guess is valid.
static int same_preferred_algorithm(char *alg1, char *alg2)
{
    // Length of each list's first entry: characters before ',' or NUL.
    // size_t keeps the full length; a 16-bit length could wrap and make
    // two different names compare equal.
    size_t len1 = strcspn(alg1, ",");
    size_t len2 = strcspn(alg2, ",");

    return (len1 == len2)  &&  (0 == strncmp(alg1, alg2, len1));
}


// Build the algorithm name-lists this client advertises in KEXINIT and
// negotiates against. Every list is arena-allocated and released by
// transport_free_supported_algorithms(). Entries are comma-separated with
// no whitespace, most preferred first.
static void transport_set_supported_algorithms(ssh_session_t *ss)
{
    int use_aes128;
    int use_3des;

    // fixme update here for more algorithms
    ss->t.supported_kex_algorithms = 
        arena_strdup("diffie-hellman-group1-sha1");
    ss->t.supported_hostkey_algorithms = 
        arena_strdup("ssh-rsa,ssh-dss");
    ss->t.supported_mac_algorithms = 
        arena_strdup("hmac-sha1");

    // cipher: aes128-cbc preferred over 3des-cbc, each enabled by prefs
    // (default on); use 3des if nothing is enabled
    use_aes128 = (PrefsGetInt(prefCipherAES128CBC, 1) != 0);
    use_3des   = (PrefsGetInt(prefCipher3DESCBC, 1) != 0);
    if (use_aes128 && use_3des) {
        ss->t.supported_cipher_algorithms = arena_strdup("aes128-cbc,3des-cbc");
    } else if (use_aes128) {
        ss->t.supported_cipher_algorithms = arena_strdup("aes128-cbc");
    } else {
        ss->t.supported_cipher_algorithms = arena_strdup("3des-cbc");
    }

    // compression
    if (PrefsGetInt(prefCompressZLib, defaultCompressZLib)) {
        // zlib preferred
        ss->t.supported_compression_algorithms = arena_strdup("zlib,none");
    } else {
        // no zlib preferred, but still allow it if required by server.
        // No space after the comma: name-lists must not contain whitespace,
        // or " zlib" would never match the server's "zlib".
        ss->t.supported_compression_algorithms = arena_strdup("none,zlib");
    }
}


// Free every algorithm name-list built by transport_set_supported_algorithms().
// Lists that were never allocated (NULL) are skipped. The fields themselves
// are not reset here.
static void transport_free_supported_algorithms(ssh_session_t *ss)
{
    if (ss->t.supported_kex_algorithms != NULL) {
        arena_free(ss->t.supported_kex_algorithms);
    }
    if (ss->t.supported_hostkey_algorithms != NULL) {
        arena_free(ss->t.supported_hostkey_algorithms);
    }
    if (ss->t.supported_cipher_algorithms != NULL) {
        arena_free(ss->t.supported_cipher_algorithms);
    }
    if (ss->t.supported_mac_algorithms != NULL) {
        arena_free(ss->t.supported_mac_algorithms);
    }
    if (ss->t.supported_compression_algorithms != NULL) {
        arena_free(ss->t.supported_compression_algorithms);
    }
}


// Return how many bytes of key material the key exchange must yield: the
// largest of the cipher key length, cipher block size and MAC key length
// over both the incoming and outgoing pending key sets.
static uint16_t transport_choose_kex_key_length(ssh_session_t *ss)
{
    uint16_t longest = 0;

    // incoming (server -> client) keys
    if (ss->t.newInKeys.enc.key_len > longest)     longest = ss->t.newInKeys.enc.key_len;
    if (ss->t.newInKeys.enc.block_size > longest)  longest = ss->t.newInKeys.enc.block_size;
    if (ss->t.newInKeys.mac.key_len > longest)     longest = ss->t.newInKeys.mac.key_len;

    // outgoing (client -> server) keys
    if (ss->t.newOutKeys.enc.key_len > longest)    longest = ss->t.newOutKeys.enc.key_len;
    if (ss->t.newOutKeys.enc.block_size > longest) longest = ss->t.newOutKeys.enc.block_size;
    if (ss->t.newOutKeys.mac.key_len > longest)    longest = ss->t.newOutKeys.mac.key_len;

    return longest;
}


// Build and send our SSH2_MSG_KEXINIT, advertising the lists from
// transport_set_supported_algorithms(). The same list is used for both
// directions of each category, languages are empty, and no guessed kex
// packet follows.
// The payload is built in ss->t.local_kexinit and kept there because the
// exact bytes sent are an input to the exchange hash (kex_dh_hash).
static void transport_send_kexinit(ssh_session_t *ss)
{
    void *cookie;
    Buffer *payload = &ss->t.local_kexinit;
    if (!payload->buf) buffer_init(payload);
    buffer_clear(payload);

    // msg
    buffer_put_char(payload, SSH2_MSG_KEXINIT);

    // 16-byte random cookie
    cookie = buffer_append_space(payload, 16);
    RAND_bytes(cookie, 16);
    
    // kex_algorithms
    buffer_put_cstring(payload, ss->t.supported_kex_algorithms);
    
    // server_host_key_algorithms
    buffer_put_cstring(payload, ss->t.supported_hostkey_algorithms);

    // encryption_algorithms (2-way)
    buffer_put_cstring(payload, ss->t.supported_cipher_algorithms);
    buffer_put_cstring(payload, ss->t.supported_cipher_algorithms);

    // mac_algorithms (2-way)
    buffer_put_cstring(payload, ss->t.supported_mac_algorithms);
    buffer_put_cstring(payload, ss->t.supported_mac_algorithms);

    // compression (2-way)
    buffer_put_cstring(payload, ss->t.supported_compression_algorithms);
    buffer_put_cstring(payload, ss->t.supported_compression_algorithms);

    // languages (2-way)
    buffer_put_cstring(payload, "");
    buffer_put_cstring(payload, "");

    // first-kex-packet-follows
    buffer_put_char(payload, 0);

    // 4-byte reserved
    buffer_put_int(payload, 0);

    packetizer_send_payload(ss, payload);

    // DO NOT clear payload here: it is ss->t.local_kexinit, which the DH
    // exchange hash reads later

    transport_log("transport: sent kexinit\r\n");
}


// Send SSH2_MSG_KEXDH_INIT carrying our Diffie-Hellman public value.
// Prints a progress line only on the initial exchange, not on a rekey.
static void transport_send_dh_init(ssh_session_t *ss, BIGNUM *publicKey)
{
    Buffer msg;

    if (ss->t.rekey == 0) {
        printf("   Exchanging keys...\r\n");
    }

    buffer_init(&msg);
    buffer_put_char(&msg, SSH2_MSG_KEXDH_INIT);   // message type
    buffer_put_bignum2(&msg, publicKey);          // mpint e
    packetizer_send_payload(ss, &msg);
    buffer_free(&msg);

    transport_log("transport: sent dh-init\r\n");
}


// Send SSH2_MSG_NEWKEYS (a bare message type, no fields) to tell the
// server we are switching to the newly derived keys.
static void transport_send_newkeys(ssh_session_t *ss)
{
    Buffer msg;

    buffer_init(&msg);
    buffer_put_char(&msg, SSH2_MSG_NEWKEYS);
    packetizer_send_payload(ss, &msg);
    buffer_free(&msg);

    transport_log("transport: sent newkeys\r\n");
}


// Reply SSH2_MSG_UNIMPLEMENTED for the incoming packet numbered `seq`.
static void transport_send_unimplemented(ssh_session_t *ss, uint32_t seq)
{
    Buffer msg;

    buffer_init(&msg);
    buffer_put_char(&msg, SSH2_MSG_UNIMPLEMENTED);
    buffer_put_int(&msg, seq);       // sequence number being rejected
    packetizer_send_payload(ss, &msg);
    buffer_free(&msg);

    transport_log("transport (state %d): sent unimplemented\r\n", ss->t.state);
}


// Send SSH2_MSG_DISCONNECT with reason `code`, description `str`, and an
// empty language tag.
static void transport_send_disconnect(ssh_session_t *ss, 
                                      uint32_t code, char *str)
{
    Buffer msg;

    buffer_init(&msg);
    buffer_put_char(&msg, SSH2_MSG_DISCONNECT);
    buffer_put_int(&msg, code);      // reason code
    buffer_put_cstring(&msg, str);   // description
    buffer_put_cstring(&msg, "");    // language tag
    packetizer_send_payload(ss, &msg);
    buffer_free(&msg);

    transport_log("transport (state %d): sent disconnect\r\n", ss->t.state);
}


// Read the server's kex_algorithms name-list from `packet` and check that
// we share at least one. Sets *guessedWrong when the two sides' first
// choices differ. Returns 1 on success, 0 if nothing matches.
static int 
transport_handle_kexinit_kex(ssh_session_t *ss, Buffer *packet, int *guessedWrong)
{
    uint16_t offered_len;
    uint8_t *offered;
    char *chosen;

    offered = buffer_get_string(packet, &offered_len);
    chosen = match_list(ss->t.supported_kex_algorithms, offered, NULL);
    if (chosen == NULL) {
        // no key exchange algorithm in common - give up
        printf("\r\ntransport: No acceptable key exchange algorithm (server offered '%s', pssh supports '%s')\r\n", offered, ss->t.supported_kex_algorithms);
        arena_free(offered);
        return 0;
    }

    // a guessed first kex packet is only usable if both sides prefer the
    // same algorithm
    if (!same_preferred_algorithm(ss->t.supported_kex_algorithms, offered)) {
        *guessedWrong = 1;
    }

    arena_free(chosen);
    arena_free(offered);
    return 1;
}
    


// Read the server's server_host_key_algorithms name-list and pick the first
// one we also support. Records its key type in ss->t.hostkey_type and sets
// *guessedWrong when the two sides' first choices differ.
// Returns 1 on success, 0 if nothing matches.
static int 
transport_handle_kexinit_server_host_key(ssh_session_t *ss, Buffer *packet, int *guessedWrong)
{
    uint16_t offered_len;
    uint8_t *offered;
    char *chosen;

    offered = buffer_get_string(packet, &offered_len);
    chosen = match_list(ss->t.supported_hostkey_algorithms, offered, NULL);
    if (chosen == NULL) {
        // no host key algorithm in common - give up
        printf("\r\ntransport: No acceptable host key algorithm (server offered '%s', pssh supports '%s')\r\n", offered, ss->t.supported_hostkey_algorithms);
        arena_free(offered);
        return 0;
    }

    ss->t.hostkey_type = key_type_from_name(chosen);

    // a guessed first kex packet is only usable if both sides prefer the
    // same host key algorithm
    if (!same_preferred_algorithm(ss->t.supported_hostkey_algorithms, offered)) {
        *guessedWrong = 1;
    }

    arena_free(chosen);
    arena_free(offered);
    return 1;
}


// Negotiate the encryption algorithms (client->server, then server->client)
// from the server's KEXINIT. The chosen ciphers initialise
// ss->t.newOutKeys.enc and ss->t.newInKeys.enc respectively.
// Returns 1 on success, 0 if either direction has no common cipher.
static int 
transport_handle_kexinit_cipher(ssh_session_t *ss, Buffer *packet)
{
    uint16_t offered_len;
    uint8_t *offered;
    char *chosen;

    /* client -> server: our outgoing cipher */
    offered = buffer_get_string(packet, &offered_len);
    chosen = match_list(ss->t.supported_cipher_algorithms, offered, NULL);
    if (chosen == NULL) {
        printf("\r\ntransport: No acceptable cipher (server offered '%s', pssh supports '%s')\r\n", offered, ss->t.supported_cipher_algorithms);
        arena_free(offered);
        return 0;
    }
    if (!ss->t.rekey) printf("%s", chosen);
    if (enc_init(&ss->t.newOutKeys.enc, chosen) < 0) {
        fatal("transport: failed to init cipher '%s' (out)", chosen);
    }
    arena_free(chosen);
    arena_free(offered);

    /* server -> client: our incoming cipher */
    offered = buffer_get_string(packet, &offered_len);
    chosen = match_list(ss->t.supported_cipher_algorithms, offered, NULL);
    if (chosen == NULL) {
        printf("\r\ntransport: No acceptable cipher (server offered '%s', pssh supports '%s')\r\n", offered, ss->t.supported_cipher_algorithms);
        arena_free(offered);
        return 0;
    }
    if (enc_init(&ss->t.newInKeys.enc, chosen) < 0) {
        fatal("transport: failed to init cipher '%s' (in)", chosen);
    }
    arena_free(chosen);
    arena_free(offered);

    return 1;
}


// Negotiate the MAC algorithms (client->server, then server->client) from
// the server's KEXINIT. The chosen MACs initialise ss->t.newOutKeys.mac
// and ss->t.newInKeys.mac respectively.
// Returns 1 on success, 0 if either direction has no common MAC.
static int 
transport_handle_kexinit_mac(ssh_session_t *ss, Buffer *packet)
{
    uint16_t offered_len;
    uint8_t *offered;
    char *chosen;

    /* client -> server: our outgoing MAC */
    offered = buffer_get_string(packet, &offered_len);
    chosen = match_list(ss->t.supported_mac_algorithms, offered, NULL);
    if (chosen == NULL) {
        printf("\r\ntransport: No acceptable MAC algorithm (server offered '%s', pssh supports '%s')\r\n", offered, ss->t.supported_mac_algorithms);
        arena_free(offered);
        return 0;
    }
    if (!ss->t.rekey) printf(" %s", chosen);
    if (mac_init(&ss->t.newOutKeys.mac, chosen) < 0) {
        fatal("transport: failed to init mac '%s' (out)", chosen);
    }
    arena_free(chosen);
    arena_free(offered);

    /* server -> client: our incoming MAC */
    offered = buffer_get_string(packet, &offered_len);
    chosen = match_list(ss->t.supported_mac_algorithms, offered, NULL);
    if (chosen == NULL) {
        printf("\r\ntransport: No acceptable MAC algorithm (server offered '%s', pssh supports '%s')\r\n", offered, ss->t.supported_mac_algorithms);
        arena_free(offered);
        return 0;
    }
    if (mac_init(&ss->t.newInKeys.mac, chosen) < 0) {
        fatal("transport: failed to init mac '%s' (in)", chosen);
    }
    arena_free(chosen);
    arena_free(offered);

    return 1;
}


// Negotiate the compression algorithms (client->server, then
// server->client) from the server's KEXINIT. They set up the outgoing
// compressor (newOutKeys) and incoming decompressor (newInKeys).
// Returns 1 on success, 0 if either direction has no common algorithm.
static int 
transport_handle_kexinit_compression(ssh_session_t *ss, Buffer *packet)
{
    uint16_t offered_len;
    uint8_t *offered;
    char *chosen;

    /* client -> server: our outgoing compressor */
    offered = buffer_get_string(packet, &offered_len);
    chosen = match_list(ss->t.supported_compression_algorithms, offered, NULL);
    if (chosen == NULL) {
        printf("\r\ntransport: No acceptable compression algorithm (server offered '%s', pssh supports '%s')\r\n", offered, ss->t.supported_compression_algorithms);
        arena_free(offered);
        return 0;
    }
    // only mention compression on screen when it is actually in use
    if (!ss->t.rekey  &&  strcmp(compress_name_is_none_guard(chosen), "none") != 0) {
        printf(" %s", chosen);
    }
    if (comp_init(&ss->t.newOutKeys.comp, &ss->t.newOutKeys.comp_ctx, chosen) < 0) {
        fatal("transport: failed to init compression '%s' (out)", chosen);
    }
    arena_free(chosen);
    arena_free(offered);

    /* server -> client: our incoming decompressor */
    offered = buffer_get_string(packet, &offered_len);
    chosen = match_list(ss->t.supported_compression_algorithms, offered, NULL);
    if (chosen == NULL) {
        printf("\r\ntransport: No acceptable compression algorithm (server offered '%s', pssh supports '%s')\r\n", offered, ss->t.supported_compression_algorithms);
        arena_free(offered);
        return 0;
    }
    if (decomp_init(&ss->t.newInKeys.comp, &ss->t.newInKeys.comp_ctx, chosen) < 0) {
        fatal("transport: failed to init compression '%s' (in)", chosen);
    }
    arena_free(chosen);
    arena_free(offered);

    return 1;
}


// Consume and discard the two language name-lists (client->server and
// server->client) from the server's KEXINIT. Always returns 1.
static int 
transport_handle_kexinit_languages(ssh_session_t *ss, Buffer *packet)
{
    uint16_t ignored_len;
    int direction;

    (void)ss;
    for (direction = 0; direction < 2; direction++) {
        arena_free(buffer_get_string(packet, &ignored_len));
    }
    return 1;
}



// Handle SSH2_MSG_IGNORE: discard it without reading its body and stay in
// the current state.
// The body is deliberately not parsed or checked for trailing bytes. Some
// servers (including SSH-1.99-Cisco-1.25) send ignore packets that don't
// conform to draft-ietf-secsh-transport-17, and OpenSSH and PuTTY both
// truly ignore these packets too.
static transport_state_t 
transport_handle_ignore(ssh_session_t *ss, Buffer *packet)
{
    (void)packet;
    return ss->t.state;
}


// Handle SSH2_MSG_DEBUG (boolean always_display, string message, string
// language). The message is only written to the transport log, whatever
// the always_display flag says. The state is unchanged.
static transport_state_t 
transport_handle_debug(ssh_session_t *ss, Buffer *packet)
{
    uint16_t message_len, language_len;
    uint8_t always_display;
    uint8_t *message;
    uint8_t *language;

    always_display = buffer_get_char(packet);
    message = buffer_get_string(packet, &message_len);
    language = buffer_get_string(packet, &language_len);
    buffer_require_empty(packet);

    if (!always_display) {
        transport_log("transport: received debug (don't display): '%s'\r\n", message);
    } else {
        transport_log("transport: received debug (display): '%s'\r\n", message);
    }

    arena_free(message);
    arena_free(language);
    return ss->t.state;
}


// Handle SSH2_MSG_DISCONNECT: print the server's reason code and
// description, then kill the session. Returns the post-kill state.
static transport_state_t 
transport_handle_disconnect(ssh_session_t *ss, Buffer *packet)
{
    uint16_t len;
    uint32_t code = buffer_get_int(packet);
    uint8_t *str = buffer_get_string(packet, &len);

    // don't bother reading language, we're about to die anyway
    // don't bother verifying empty, we're about to die anyway

    // uint32_t is not necessarily `long`; cast explicitly to match %lu
    printf("transport: received disconnect (reason %lu '%s')\r\n", 
           (unsigned long)code, str);
    arena_free(str);  // release before teardown

    ssh_kill(ss);
    return ss->t.state;
}


// Handle SSH2_MSG_UNIMPLEMENTED from the server: treat it as fatal and kill
// the session. Returns the post-kill state.
// The SSH spec currently requires every message this client sends to be
// implemented, so a rejection means something is badly wrong. If optional
// messages are ever sent, a transport_expect_unimplemented-style callback
// could handle them instead. The body (the rejected sequence number) is not
// read or checked, since the session is about to be torn down.
static transport_state_t 
transport_handle_unimplemented(ssh_session_t *ss, Buffer *packet)
{
    (void)packet;
    printf("transport (state %d): received unimplemented\r\n", ss->t.state);
    ssh_kill(ss);
    return ss->t.state;
}


// Handle the server's SSH2_MSG_KEXINIT, on the initial exchange or a
// server-initiated rekey.
//  - Saves the full KEXINIT payload in ss->t.remote_kexinit; it is later
//    passed to kex_dh_hash.
//  - Negotiates each algorithm category against our supported lists.
//  - Records whether a guessed first kex packet must be discarded.
//  - Generates a group1 Diffie-Hellman key pair and sends KEXDH_INIT.
// Returns TRANSPORT_STATE_WAITING_FOR_DH_REPLY on success. If negotiation
// fails, the session is killed and the post-kill state is returned.
static transport_state_t 
transport_handle_kexinit(ssh_session_t *ss, Buffer *packet)
{
    int algOK;
    uint8_t guess;
    int guessedWrong;
    Buffer *rkex = &ss->t.remote_kexinit;

    // save kexinit packet for DH
    if (!rkex->buf) buffer_init(rkex);
    buffer_clear(rkex);
    // prepend message field (but not length field!)
    // The message-type byte was already consumed by transport_receive_packet,
    // so it is re-added to reconstruct the exact payload the server sent.
    buffer_put_char(rkex, SSH2_MSG_KEXINIT);
    buffer_append(rkex, buffer_ptr(packet), buffer_len(packet));

    buffer_consume(packet, 16); // 16-byte random cookie

    // Assume remote's protocol guess is correct, if any
    // Guess is WRONG if best kex or best server_host_key algorithms differ
    guessedWrong = 0;

    // read and check algorithms offered by server
    // Each helper consumes its name-list(s) in wire order; after the first
    // failure the remaining helpers are skipped.
    if (!ss->t.rekey) printf("(");
    algOK = 1;
    if (algOK) algOK = transport_handle_kexinit_kex(ss, packet, &guessedWrong);
    if (algOK) algOK = transport_handle_kexinit_server_host_key(ss, packet, &guessedWrong);
    if (algOK) algOK = transport_handle_kexinit_cipher(ss, packet);
    if (algOK) algOK = transport_handle_kexinit_mac(ss, packet);
    if (algOK) algOK = transport_handle_kexinit_compression(ss, packet);
    if (algOK) algOK = transport_handle_kexinit_languages(ss, packet);
    if (!algOK) {
        ssh_kill(ss);
        return ss->t.state;
    }
    if (!ss->t.rekey) printf(")\r\n");


    buffer_get(packet, &guess, 1); // first_kex_packet_follows
    buffer_get_int(packet);        // unused reserved
    buffer_require_empty(packet);

    if (guess) {
        printf("transport: remote sent guessed packet\r\n");
        if (guessedWrong) {
            // the next packet used the wrong algorithm
            // (transport_receive_packet drops it when ignore_next_packet is set)
            printf("transport: guessed packet is WRONG\r\n");
            ss->t.ignore_next_packet = 1;
        } else {
            // the next packet is useful
            printf("transport: guessed packet is CORRECT\r\n");
        }
    }

    // choose kex key length based on negotiated ciphers and MACs
    ss->t.kex_key_length = transport_choose_kex_key_length(ss);

    // start diffie-hellman key exchange
    // Derived from part of OpenSSH's kexdh_client()

    if (!ss->t.rekey) printf("   Generating key...\r\n");

    ss->t.dh = dh_new_group1();
    transport_log("new group1 done");
    // second argument is the needed key material in bits; presumably used
    // by dh_gen_key to size the private exponent, as in OpenSSH (confirm)
    dh_gen_key(ss->t.dh, ss->t.kex_key_length * 8);
    transport_log("gen key done");
    transport_send_dh_init(ss, ss->t.dh->pub_key);
    transport_log("transport_handle_kexinit done");
    return TRANSPORT_STATE_WAITING_FOR_DH_REPLY;
}


// Handle SSH2_MSG_KEXDH_REPLY:
//  - check the server host key
//  - complete the Diffie-Hellman exchange
//  - verify the server's signature over the exchange hash H
//  - derive the new keys and send NEWKEYS
// Returns TRANSPORT_STATE_WAITING_FOR_NEWKEYS on success. If the host key is
// rejected, the session is killed and the post-kill state is returned.
// Protocol and crypto failures call fatal().
// Derived from part of OpenSSH's kexdh_client()
static transport_state_t 
transport_handle_dh_reply(ssh_session_t *ss, Buffer *packet)
{
    BIGNUM *dh_server_pub = NULL, *shared_secret = NULL;
    DH *dh = ss->t.dh;
    Key *server_host_key;
    uint8_t *server_host_key_blob, *signature;
    uint8_t *kbuf, *hash;
    uint16_t klen, kout, slen, sbloblen;
    int computed;

    /* key, cert */
    server_host_key_blob = buffer_get_string(packet, &sbloblen);
    server_host_key = key_from_blob(server_host_key_blob, sbloblen);
    if (server_host_key == NULL)
        fatal("cannot decode server_host_key_blob");
    if (server_host_key->type != ss->t.hostkey_type)
        fatal("type mismatch for decoded server_host_key_blob");
    if (!check_host_key(ss->s.hostname, ss->s.hostaddr, server_host_key)) {
        printf("   Server host key rejected.\r\n");
        // release the decoded key and its blob before tearing down
        key_free(server_host_key);
        arena_free(server_host_key_blob);
        ssh_kill(ss);
        return ss->t.state;
    }

    if (!ss->t.rekey) printf("   Calculating shared secret...\r\n");

    /* DH parameter f, server public DH key */
    dh_server_pub = BN_new();
    buffer_get_bignum2(packet, dh_server_pub);

    /* signed H */
    signature = buffer_get_string(packet, &slen);
    buffer_require_empty(packet);

    if (!dh_pub_is_valid(dh, dh_server_pub)) 
        fatal("bad server public DH value");

    klen = DH_size(dh);
    kbuf = arena_malloc(klen);
    // DH_compute_key() returns the secret's length, or -1 on error. Check
    // before narrowing to uint16_t: -1 would become 65535 and BN_bin2bn
    // would read far past the end of kbuf.
    computed = DH_compute_key(kbuf, dh_server_pub, dh);
    if (computed < 0)
        fatal("DH_compute_key failed");
    kout = (uint16_t)computed;
    transport_log("computed shared secret");
    shared_secret = BN_new();
    BN_bin2bn(kbuf, kout, shared_secret);
    memset(kbuf, 0, klen);   // scrub the raw shared secret
    arena_free(kbuf);

    /* calc and verify H */
    hash = kex_dh_hash(
                       ss->p.local_version_string, 
                       ss->p.remote_version_string, 
                       &ss->t.local_kexinit, 
                       &ss->t.remote_kexinit, 
                       server_host_key_blob, sbloblen, 
                       dh->pub_key, 
                       dh_server_pub, 
                       shared_secret
                       );
    arena_free(server_host_key_blob);
    BN_clear_free(dh_server_pub);
    DH_free(dh); ss->t.dh = dh = NULL;

    if (key_verify(server_host_key, signature, slen, hash, SHA_DIGEST_LENGTH) != 1)
        fatal("key_verify failed for server_host_key");
    key_free(server_host_key);

    arena_free(signature);
    
    /* save session id (which does not get replaced during rekey) */
    if (ss->t.session_id == NULL) {
        ss->t.session_id_len = SHA_DIGEST_LENGTH;
        ss->t.session_id = arena_malloc(SHA_DIGEST_LENGTH);
        memcpy(ss->t.session_id, hash, SHA_DIGEST_LENGTH);
    }

    transport_log("deriving keys");
    kex_derive_keys(&ss->t.newInKeys, &ss->t.newOutKeys, ss->t.kex_key_length, 
                    hash, shared_secret, ss->t.session_id, SHA_DIGEST_LENGTH);
    transport_log("derived keys");
    BN_clear_free(shared_secret);
    arena_free(hash);

    // send NEWKEYS and wait for remote's NEWKEYS
    transport_send_newkeys(ss);
    transport_log("sent newkeys");
    return TRANSPORT_STATE_WAITING_FOR_NEWKEYS;
}


// Handle the server's SSH2_MSG_NEWKEYS (which carries no fields).
// Switches the packetizer to the newly derived keys, starts the connection
// layer, marks the first key exchange as done, and enters RUNNING.
static transport_state_t 
transport_handle_newkeys(ssh_session_t *ss, Buffer *packet)
{
    buffer_require_empty(packet);
    transport_log("changing crypto");
    packetizer_change_crypto(ss, &ss->t.newInKeys, &ss->t.newOutKeys);
    transport_log("changed crypto");
    connection_start(ss); // ok if it's already been started
    transport_log("connection started");
    // from now on kex runs are rekeys (suppresses progress output, keeps session_id)
    ss->t.rekey = true; // first kex definitely complete
    return TRANSPORT_STATE_RUNNING;
}


// Handle a KEXINIT received while RUNNING (server-initiated rekey).
// Sends our own fresh KEXINIT first, then processes the server's exactly as
// in the initial exchange.
static transport_state_t 
transport_handle_re_kexinit(ssh_session_t *ss, Buffer *packet)
{
    // server requested rekey (we never request it)
    transport_send_kexinit(ss);
    return transport_handle_kexinit(ss, packet);
}


// State WAITING_FOR_KEXINIT: only the server's KEXINIT is accepted. Any
// other message is answered with UNIMPLEMENTED and the state is unchanged.
static transport_state_t 
transport_state_waiting_for_kexinit(ssh_session_t *ss, 
                                    uint8_t msg, Buffer *packet) 
{
    if (msg != SSH2_MSG_KEXINIT) {
        transport_send_unimplemented(ss, ss->p.seqnoIn);
        return ss->t.state;
    }
    return transport_handle_kexinit(ss, packet);
}


// State WAITING_FOR_DH_REPLY: only KEXDH_REPLY is accepted. Any other
// message is answered with UNIMPLEMENTED and the state is unchanged.
static transport_state_t 
transport_state_waiting_for_dh_reply(ssh_session_t *ss, 
                                     uint8_t msg, Buffer *packet) 
{
    if (msg != SSH2_MSG_KEXDH_REPLY) {
        transport_send_unimplemented(ss, ss->p.seqnoIn);
        return ss->t.state;
    }
    return transport_handle_dh_reply(ss, packet);
}

// State WAITING_FOR_NEWKEYS: only NEWKEYS is allowed here (the "anytime"
// messages are handled before state dispatch). Anything else is a protocol
// violation and kills the session.
static transport_state_t 
transport_state_waiting_for_newkeys(ssh_session_t *ss, 
                                    uint8_t msg, Buffer *packet) 
{
    if (msg != SSH2_MSG_NEWKEYS) {
        ssh_kill(ss);
        return ss->t.state;
    }
    return transport_handle_newkeys(ss, packet);
}

// State RUNNING:
//  - KEXINIT starts a server-initiated rekey.
//  - User-auth and connection messages, and SERVICE_ACCEPT, are passed to
//    the connection layer.
//  - Anything else gets UNIMPLEMENTED.
static transport_state_t 
transport_state_running(ssh_session_t *ss, 
                        uint8_t msg, Buffer *packet) 
{
    int for_connection_layer;

    if (msg == SSH2_MSG_KEXINIT) {
        return transport_handle_re_kexinit(ss, packet);
    }

    // fixme SERVICE_ACCEPT should be moved from connection to transport
    for_connection_layer =
        (msg >= SSH2_MSG_CONNECTION_MIN  &&  msg <= SSH2_MSG_CONNECTION_MAX)  ||
        (msg >= SSH2_MSG_USERAUTH_MIN  &&  msg <= SSH2_MSG_USERAUTH_MAX)  ||
        msg == SSH2_MSG_SERVICE_ACCEPT;

    if (!for_connection_layer) {
        transport_send_unimplemented(ss, ss->p.seqnoIn);
        return ss->t.state;
    }

    connection_receive_packet(ss, msg, packet);
    // the connection layer may have changed ss->t.state; report it as-is
    return ss->t.state;
}


// Entry point for every incoming payload. Presumably called by the
// packetizer once a packet is decrypted (TODO confirm). The first byte of
// `packet` is the message type.
// Steps, in order:
//  1. Drop the packet entirely if a wrong guessed kex packet is expected.
//  2. Handle the "anytime" messages (IGNORE, DEBUG, DISCONNECT,
//     UNIMPLEMENTED) in any state.
//  3. Dispatch everything else to the handler for the current state.
//  4. If the transport ends up RUNNING, flush payloads queued while it
//     was not.
void transport_receive_packet(ssh_session_t *ss, Buffer *packet)
{
    uint8_t msg;

    if (ss->t.ignore_next_packet) {
        // Throw this packet away (bad algorithm guess during kexinit)
        ss->t.ignore_next_packet = 0;
        return;
    }

    buffer_get(packet, &msg, 1);

    transport_log("transport (state %d): got %d\r\n", ss->t.state, msg);

    switch (msg) {
    case SSH2_MSG_IGNORE:
        ss->t.state = transport_handle_ignore(ss, packet);
        break;
    case SSH2_MSG_DEBUG:
        ss->t.state = transport_handle_debug(ss, packet);
        break;
    case SSH2_MSG_DISCONNECT:
        ss->t.state = transport_handle_disconnect(ss, packet);
        break;
    case SSH2_MSG_UNIMPLEMENTED:
        ss->t.state = transport_handle_unimplemented(ss, packet);
        break;

    default: 
        // state-specific messages
        switch (ss->t.state) {
        case TRANSPORT_STATE_WAITING_FOR_KEXINIT:
            ss->t.state = transport_state_waiting_for_kexinit(ss, msg, packet);
            break;
        case TRANSPORT_STATE_WAITING_FOR_DH_REPLY:
            ss->t.state = transport_state_waiting_for_dh_reply(ss, msg,packet);
            break;
        case TRANSPORT_STATE_WAITING_FOR_NEWKEYS:
            ss->t.state = transport_state_waiting_for_newkeys(ss, msg, packet);
            break;
        case TRANSPORT_STATE_RUNNING:
            ss->t.state = transport_state_running(ss, msg, packet);
            break;
        case TRANSPORT_STATE_CLOSED:
            printf("transport: packet (type %d) sent to closed state (%d %d %d %d)\r\n", msg, ss->s.state, ss->p.state, ss->t.state, ss->c.state);
            ssh_kill(ss);
            break;
        default:
            // unexpected state value (e.g. STARTING): internal error
            printf("transport: BUSTED\r\n");
            ssh_kill(ss);
            break;
        }
        break;
    }

    // Send queued packets (if any) when RUNNING
    if (ss->t.state == TRANSPORT_STATE_RUNNING) {
        transport_send_queued_payloads(ss);
    }
}


// Begin key exchange: build our algorithm lists, send KEXINIT, and wait
// for the server's. Does nothing unless the transport is still STARTING.
void transport_start(ssh_session_t *ss)
{
    if (ss->t.state != TRANSPORT_STATE_STARTING) return;

    transport_set_supported_algorithms(ss);
    printf("   Negotiating algorithms... ");
    transport_send_kexinit(ss);
    ss->t.state = TRANSPORT_STATE_WAITING_FOR_KEXINIT;
}


// Release everything the transport layer owns, then zero ss->t and mark
// it CLOSED. Called from transport_kill().
static void transport_cleanup(ssh_session_t *ss)
{
    int i;

    if (ss->t.dh) DH_free(ss->t.dh);
    if (ss->t.session_id) arena_free(ss->t.session_id); 
    if (ss->t.packet_queue) {
        // Each live queue entry holds its own buffer_init'ed copy of a
        // payload (see transport_send_payload); free those before the array
        // so undelivered payloads are not leaked.
        for (i = 0; i < ss->t.packet_queue_used; i++) {
            buffer_free(&ss->t.packet_queue[i]);
        }
        arena_free(ss->t.packet_queue);
    }
    // The kexinit buffers are initialised lazily (buf stays NULL until
    // first use), so only free the ones that were actually set up.
    if (ss->t.remote_kexinit.buf) buffer_free(&ss->t.remote_kexinit);
    if (ss->t.local_kexinit.buf) buffer_free(&ss->t.local_kexinit);
    keys_free(&ss->t.newInKeys);
    keys_free(&ss->t.newOutKeys);
    transport_free_supported_algorithms(ss);
    
    memset(&ss->t, 0, sizeof(ss->t));
    ss->t.state = TRANSPORT_STATE_CLOSED;
}


// Tear down the transport immediately, without notifying the peer.
// Safe to call repeatedly: an already CLOSED transport is left alone.
void transport_kill(ssh_session_t *ss)
{
    if (ss->t.state == TRANSPORT_STATE_CLOSED) return;
    transport_cleanup(ss);
}


// Close the transport gracefully. If negotiation has begun (state is
// neither STARTING nor CLOSED), send DISCONNECT and tear down. The state
// always ends up CLOSED.
void transport_close(ssh_session_t *ss)
{
    int active = !(ss->t.state == TRANSPORT_STATE_CLOSED  ||  
                   ss->t.state == TRANSPORT_STATE_STARTING);

    if (active) {
        transport_send_disconnect(ss, 0, "");
        transport_kill(ss);
    }
    ss->t.state = TRANSPORT_STATE_CLOSED;
}


// Return 1 once negotiation has started and until the transport closes,
// i.e. in any state other than STARTING or CLOSED; otherwise 0.
int transport_is_open(ssh_session_t *ss) {
    switch (ss->t.state) {
    case TRANSPORT_STATE_STARTING:
    case TRANSPORT_STATE_CLOSED:
        return 0;
    default:
        return 1;
    }
}


// The transport has no intermediate "closing" state: transport_close()
// moves straight to CLOSED. So this always returns 0.
int transport_is_closing(ssh_session_t *ss) {
    (void)ss;
    return 0;
}


// Used by connection layer and vt100.
// RUNNING: the payload is sent immediately.
// CLOSED: the payload is silently discarded.
// Any other state (before startup, during kex and rekey): a private copy is
// queued, to be sent by transport_send_queued_payloads() once RUNNING.
void transport_send_payload(ssh_session_t *ss, Buffer *payload)
{
    Buffer *copy;

    if (ss->t.state == TRANSPORT_STATE_CLOSED) return;   // drop it

    if (ss->t.state == TRANSPORT_STATE_RUNNING) {
        packetizer_send_payload(ss, payload);
        return;
    }

    // queue full: grow capacity to 2n+1 (0 -> 1 -> 3 -> 7 ...)
    if (ss->t.packet_queue_used == ss->t.packet_queue_allocated) {
        ss->t.packet_queue_allocated = 2 * ss->t.packet_queue_allocated + 1;
        ss->t.packet_queue = arena_realloc(ss->t.packet_queue, 
                                           sizeof(Buffer) * 
                                           ss->t.packet_queue_allocated);
    }

    copy = &ss->t.packet_queue[ss->t.packet_queue_used];
    ss->t.packet_queue_used++;
    buffer_init(copy);
    buffer_append(copy, buffer_ptr(payload), buffer_len(payload));
}


// Send every payload queued by transport_send_payload(), oldest first,
// freeing each copy after it is handed to the packetizer, then empty the
// queue. The array itself is kept for reuse.
static void transport_send_queued_payloads(struct ssh_session_t *ss)
{
    int n;

    if (!(ss->t.packet_queue_used > 0)) return;   // nothing pending

    for (n = 0; n < ss->t.packet_queue_used; n++) {
        packetizer_send_payload(ss, &ss->t.packet_queue[n]);
        buffer_free(&ss->t.packet_queue[n]);
    }
    ss->t.packet_queue_used = 0;
}
