/**********
 * Copyright (c) 2003-2005 Greg Parker.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY GREG PARKER ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 **********/

#include "includes.h"
#include "packetizer.h"

#include "session.h"
#include "transport.h"
#include "connection.h"
#include "ssh2.h"
#include "openssh/buffer.h"
#include "openssh/bufaux.h"
#include "openssh/cipher.h"
#include "openssh/mac.h"
#include "crypto/rand.h"


// Forward declarations for this file's private (static) helpers.
// NOTE(review): PACKETIZER_SEGMENT is defined outside this file (presumably
// in packetizer.h) and appears to be a code-placement attribute applied to
// every static helper here; confirm its meaning before adding new helpers.
static char *chop(char *s) PACKETIZER_SEGMENT;
static void packetizer_decrypt(ssh_session_t *ss, uint8_t *dst, uint8_t *src, uint16_t length) PACKETIZER_SEGMENT;
static packetizer_state_t packetizer_handle_version_start(ssh_session_t *ss, Buffer *b) PACKETIZER_SEGMENT;
static packetizer_state_t packetizer_handle_not_version(ssh_session_t *ss, Buffer *b) PACKETIZER_SEGMENT;
static packetizer_state_t packetizer_handle_version(ssh_session_t *ss, Buffer *b) PACKETIZER_SEGMENT;
static packetizer_state_t packetizer_handle_header(ssh_session_t *ss, Buffer *b) PACKETIZER_SEGMENT;
static packetizer_state_t packetizer_handle_body(ssh_session_t *ss, Buffer *b) PACKETIZER_SEGMENT;
static void packetizer_cleanup(ssh_session_t *ss) PACKETIZER_SEGMENT;



/* remove newline at end of string */
/*
 * Truncate s in place at its first '\r' or '\n', wherever that character
 * occurs (not only at the very end of the string). Strings containing
 * neither character are left unchanged. Returns s.
 */
static char *chop(char *s) 
{
    size_t i;

    for (i = 0; s[i] != '\0'; i++) {
        if (s[i] == '\n' || s[i] == '\r') {
            s[i] = '\0';
            break;
        }
    }
    return s;
}


/*
 * Decrypt `length` bytes from src into dst using the inbound cipher context.
 * Until an inbound cipher has been enabled (see packetizer_change_crypto),
 * the bytes are copied through unchanged.
 */
static void packetizer_decrypt(ssh_session_t *ss, uint8_t *dst, uint8_t *src, 
                               uint16_t length)
{
    if (ss->p.inKeys.enc.enabled) {
        cipher_crypt(&ss->p.inContext, dst, src, length);
        return;
    }

    // plaintext phase of the connection: nothing to undo
    memcpy(dst, src, length);
}

// Defined elsewhere. packetizer_send_payload() calls it before every write
// (unless the session is closing) and kills the session if it returns false.
extern Boolean RefreshNetLib(void); // fixme put this in a header file

/*
 * Frame `payload` as one outbound SSH binary packet and write it to
 * ss->s.socket with a single write().
 *
 * Order of operations:
 *   1. optional compression via keys->comp->transcode (if it fails, the
 *      uncompressed payload is sent as-is);
 *   2. packet_length (buffer_put_int), padding_length byte, payload,
 *      random padding;
 *   3. MAC over the still-unencrypted packet and ss->p.seqnoOut, if the
 *      outbound MAC is enabled;
 *   4. encryption with ss->p.outContext, if the outbound cipher is enabled;
 *   5. MAC bytes appended after the (possibly encrypted) packet.
 *
 * `payload` is only read; it is never consumed or freed here.
 * If RefreshNetLib() fails (and ss->closing is not set) or the write is
 * short/failed, ssh_kill() is called and ss->p.seqnoOut is NOT incremented.
 * The packet buffer is function-static and reused across calls, so this
 * function is not reentrant.
 */
void packetizer_send_payload(ssh_session_t *ss, Buffer *payload)
{
    // fixme what about blocking write?
    // fixme implement write queue
    void *p;
    // NOTE(review): 16-bit lengths; payloads approaching 64 KiB would
    // overflow these. Confirm callers never send payloads that large.
    uint16_t packet_length;
    uint16_t padded_length;
    size_t written;
    uint8_t padding_length;
    // Reused across calls; initialized lazily on first use below.
    static Buffer packet = {0, 0, 0, 0};
    uint8_t *macbuf = NULL;
    Keys *keys = &ss->p.outKeys;
    // Owns the compressed copy of the payload when compression succeeds.
    Buffer comp_buffer = {0, 0, 0, 0};

    if (!packet.buf) buffer_init(&packet);
    buffer_clear(&packet);

    // compress payload before calculating lengths
    if (keys->comp) {
        unsigned char *newPayload;
        size_t newLen;
        if (keys->comp->transcode(keys->comp_ctx, 
                                  buffer_ptr(payload), buffer_len(payload), 
                                  &newPayload, &newLen)) 
        {
            // Wrap the transcoder's output so it can be released with
            // buffer_free() once the packet has been built.
            comp_buffer.buf = newPayload;
            comp_buffer.alloc = newLen;
            comp_buffer.offset = 0;
            comp_buffer.end = newLen;
            payload = &comp_buffer;
        }
    }

    padding_length = 4; // 4-byte minimum random padding
    // length of everything but MAC must be multiple of max(8, block size)
    // (ss->p.outHeaderLength holds that value). The adjustment adds between
    // 1 and outHeaderLength bytes, so total padding is 5..4+outHeaderLength.
    padded_length = 4 + 1 + buffer_len(payload) + padding_length;
    padding_length += ss->p.outHeaderLength - 
        (padded_length % ss->p.outHeaderLength);

    // packet_length field doesn't include itself
    packet_length = 1 + buffer_len(payload) + padding_length;

    // packet_length
    buffer_put_int(&packet, packet_length);
    
    // padding_length
    buffer_put_char(&packet, padding_length);

    // payload (maybe already compressed)
    buffer_append(&packet, buffer_ptr(payload), buffer_len(payload));

    // random padding
    p = buffer_append_space(&packet, padding_length);
    RAND_bytes(p, padding_length);
    
    // compute MAC across unencrypted bytes
    if (keys->mac.enabled) {
        macbuf = mac_compute(&keys->mac, ss->p.seqnoOut, 
                             buffer_ptr(&packet), buffer_len(&packet));
    }

    // encrypt
    if (ss->p.outKeys.enc.enabled) {
        // Encrypt into a fresh buffer, free the plaintext storage, and let
        // the static `packet` take over the encrypted buffer's storage.
        Buffer encrypted;
        buffer_init(&encrypted);
        buffer_append_space(&encrypted, buffer_len(&packet));
        cipher_crypt(&ss->p.outContext, buffer_ptr(&encrypted), 
                     buffer_ptr(&packet), buffer_len(&packet));
        buffer_free(&packet);
        packet = encrypted;
    }

    // append MAC
    if (keys->mac.enabled) {
        buffer_append(&packet, macbuf, keys->mac.mac_len);
        arena_free(macbuf);
    }

    packetizer_log("packetizer: sending packet (%d bytes: 4+1+%d+%d+mac)\r\n", 
           buffer_len(&packet), buffer_len(payload), padding_length);

    // `payload` may point at comp_buffer; it is not used after this point.
    if (comp_buffer.buf) buffer_free(&comp_buffer);

    // fixme implement write queue

    // Re-open network connection (may have been disconnected after 
    // auto power-off, for example)
    // fixme call RefreshNetLib less often
    if (ss->closing) {
        // do not refresh network while closing
    } else if (!RefreshNetLib()) {
        ssh_kill(ss);
        return;
    }

    // fixme this kill sometimes crashes?
    // A negative write() result converts to a huge size_t, so it also
    // fails this comparison and is treated like a short write.
    written = write(ss->s.socket, buffer_ptr(&packet), buffer_len(&packet));
    if (written != buffer_len(&packet)) {
        ssh_kill(ss);
        return;
    }

    // Only advance the sequence number once the whole packet was written.
    ss->p.seqnoOut++;
}


/*
 * Install freshly negotiated keys for both directions.
 *
 * Frees the current keys and cipher contexts, then takes ownership of
 * *newInKeys / *newOutKeys: they are struct-copied into ss->p and the
 * caller's structs are zeroed. Also:
 *   - sets the header block length to max(8, cipher block size);
 *   - enables the MAC / cipher for each direction whose md / cipher is set;
 *   - re-initializes both cipher contexts;
 *   - creates a compression context for each direction whose NEW keys
 *     carry a compression method.
 */
void packetizer_change_crypto(ssh_session_t *ss, 
                              Keys *newInKeys, Keys *newOutKeys)
{
    // These are addresses of fields inside ss->p, so after the struct
    // assignments below they refer to the NEW keys.
    Enc *inEnc = &ss->p.inKeys.enc;
    Mac *inMac = &ss->p.inKeys.mac;
    Enc *outEnc = &ss->p.outKeys.enc;
    Mac *outMac = &ss->p.outKeys.mac;
    Comp *inComp;
    Comp *outComp;

    // cleanup old keys
    keys_free(&ss->p.inKeys);
    keys_free(&ss->p.outKeys);
    cipher_cleanup(&ss->p.inContext);
    cipher_cleanup(&ss->p.outContext);

    ss->p.inKeys = *newInKeys;
    ss->p.outKeys = *newOutKeys;
    ss->p.inHeaderLength = MAX(8, inEnc->block_size);
    ss->p.outHeaderLength = MAX(8, outEnc->block_size);

    // The compression methods are pointer VALUES, not field addresses, so
    // they must be read only after the assignments above. Reading them
    // earlier would yield the previous keys' methods, and a newly
    // negotiated method would never get a context for transcode().
    inComp = ss->p.inKeys.comp;
    outComp = ss->p.outKeys.comp;

    memset(newInKeys, 0, sizeof(*newInKeys));
    memset(newOutKeys, 0, sizeof(*newOutKeys));

    // mac
    if (inMac->md) inMac->enabled = 1;
    if (outMac->md) outMac->enabled = 1;

    // cipher
    if (inEnc->cipher) inEnc->enabled = 1;
    if (outEnc->cipher) outEnc->enabled = 1;

    // cipher context
    cipher_init(&ss->p.inContext, inEnc->cipher, inEnc->key, inEnc->key_len, 
                inEnc->iv, inEnc->block_size, CIPHER_DECRYPT);
    cipher_init(&ss->p.outContext, outEnc->cipher, outEnc->key,outEnc->key_len,
                outEnc->iv, outEnc->block_size, CIPHER_ENCRYPT);

    // compression
    if (outComp) {
        ss->p.outKeys.comp_ctx = outComp->init();
    }
    if (inComp) {
        ss->p.inKeys.comp_ctx = inComp->init();
    }
}


/*
 * WAITING_FOR_VERSION handler: decide whether the next buffered line is the
 * server's "SSH-..." version line or some other line. RFC 4253 section 4.2
 * lets servers send other lines before the version line.
 *
 * Returns:
 *   READING_VERSION     - buffer starts with "SSH-" (left in the buffer);
 *   READING_NOT_VERSION - buffer cannot start with "SSH-"; that line is
 *                         skipped by packetizer_handle_not_version;
 *   current state       - buffer is empty or holds only a partial "SSH-"
 *                         prefix; wait for more data instead of
 *                         misclassifying the real version line as junk.
 */
static packetizer_state_t 
packetizer_handle_version_start(ssh_session_t *ss, Buffer *b)
{
    static const char prefix[] = "SSH-";
    const size_t prefixLen = sizeof(prefix) - 1;
    uint8_t *bufData = buffer_ptr(b);
    size_t bufLen = buffer_len(b);
    size_t cmpLen = (bufLen < prefixLen) ? bufLen : prefixLen;

    // Do not print bufData with %s: network data is not NUL-terminated.
    packetizer_log("waiting-for-version (%d bytes buffered)\r\n", (int)bufLen);

    if (0 != memcmp(bufData, prefix, cmpLen)) {
        // leave data in buffer to act as overflow check
        return PACKETIZER_STATE_READING_NOT_VERSION;
    }

    if (bufLen < prefixLen) {
        // empty or partial "SSH-": wait for the rest before deciding
        return ss->p.state;
    }

    // leave "SSH-" in buffer for eventual version string
    return PACKETIZER_STATE_READING_VERSION;
}



/*
 * READING_NOT_VERSION handler: discard one non-version line, up to and
 * including its '\n', then go back to WAITING_FOR_VERSION. Until a newline
 * arrives, the partial line stays buffered and the state is unchanged.
 */
static packetizer_state_t 
packetizer_handle_not_version(ssh_session_t *ss, Buffer *b)
{
    uint8_t *start = buffer_ptr(b);
    uint16_t avail = buffer_len(b);
    uint8_t *eol;

    packetizer_log("reading-not-version\r\n");

    eol = memchr(start, '\n', avail);
    if (eol == NULL) {
        // incomplete line stays buffered (doubles as an overflow check)
        return ss->p.state;
    }

    buffer_consume(b, (eol - start) + 1);
    return PACKETIZER_STATE_WAITING_FOR_VERSION;
}


/*
 * READING_VERSION handler: once a full line (ending in '\n') is buffered,
 * move it into ss->p.remote_version_string (CR/LF stripped), warn on
 * stdout if it is not an SSH-2.0 / SSH-1.99 version, start the transport
 * layer, and switch to reading packet headers.
 *
 * Until a newline arrives, the partial line stays buffered and the current
 * state is returned. An unrecognized version is only reported; the session
 * proceeds regardless.
 */
static packetizer_state_t 
packetizer_handle_version(ssh_session_t *ss, Buffer *b)
{
    uint8_t *data = buffer_ptr(b);
    uint16_t avail = buffer_len(b);
    uint8_t *newline;
    size_t lineLen;
    char *version;
    int accepted;

    packetizer_log("reading-version\r\n");

    newline = memchr(data, '\n', avail);
    if (newline == NULL) {
        // keep waiting for the rest of the version line
        return ss->p.state;
    }

    // take the whole line, terminator included, out of the buffer
    lineLen = (newline - data) + 1;
    version = arena_malloc(lineLen + 1);
    ss->p.remote_version_string = version;
    buffer_get(b, version, lineLen);
    version[lineLen] = '\0';
    chop(version);

    packetizer_log("got remote version %s\r\n", 
                   version);

    // We accept "SSH-2.0-..." or "SSH-1.99-..." (or the bare forms)
    accepted = (strcmp(version, "SSH-2.0") == 0)  ||
               (strcmp(version, "SSH-1.99") == 0)  ||
               (strncmp(version, "SSH-2.0-", 8) == 0)  ||
               (strncmp(version, "SSH-1.99-", 9) == 0);
    if (!accepted) {
        printf("unknown SSH protocol version '%s'\r\n", 
               version);
    }

    transport_start(ss);

    return PACKETIZER_STATE_READING_PACKET_HEADER;
}


/*
 * READING_PACKET_HEADER handler: once at least inHeaderLength bytes are
 * buffered, decrypt that first block into ss->p.decryptedData (where it
 * stays for the MAC check) and parse packet_length and padding_length.
 *
 * packet_length is a big-endian uint32 on the wire (RFC 4253 section 6). It
 * is decoded byte by byte, so the result depends neither on host byte
 * order nor on the alignment of the decrypted block.
 *
 * The session is killed (via ssh_kill) for packets that are:
 *   - larger than 35000 bytes;
 *   - shorter than the already-decrypted block;
 *   - padded so much that no room is left for the padding_length byte.
 * packetizer_handle_body relies on these invariants: it computes lengths in
 * uint16_t and strips the padding from the decrypted data.
 *
 * Returns READING_PACKET_BODY on success, otherwise the current state.
 */
static packetizer_state_t 
packetizer_handle_header(ssh_session_t *ss, Buffer *b)
{
    uint8_t *bufData = buffer_ptr(b);
    uint16_t bufLen = buffer_len(b);
    uint8_t *decrypt_dst;

    packetizer_log("reading-packet-header\r\n");
    if (bufLen < ss->p.inHeaderLength) {
        // wait for more data to decrypt packet size
        return ss->p.state;
    }

    decrypt_dst = 
        buffer_append_space(&ss->p.decryptedData, ss->p.inHeaderLength);
    packetizer_decrypt(ss, decrypt_dst, bufData, ss->p.inHeaderLength);
    buffer_consume(b, ss->p.inHeaderLength);

    // network byte order, decoded without an unaligned/aliasing cast
    ss->p.packetLength = ((uint32_t)decrypt_dst[0] << 24) |
                         ((uint32_t)decrypt_dst[1] << 16) |
                         ((uint32_t)decrypt_dst[2] << 8)  |
                          (uint32_t)decrypt_dst[3];
    ss->p.paddingLength = decrypt_dst[4];

    // max packet length specified by SSH protocol spec
    if (ss->p.packetLength > 35000) {
        printf("\r\npacketizer: packet too big (%lu bytes)\r\n", ss->p.packetLength);
        ssh_kill(ss);
        return ss->p.state;
    }

    // The packet must cover the block already decrypted (else the body
    // length underflows), and padding must leave room for the
    // padding_length byte itself (payload length >= 0).
    if (ss->p.packetLength + 4 < ss->p.inHeaderLength  ||
        ss->p.paddingLength >= ss->p.packetLength)
    {
        printf("\r\npacketizer: bad packet length (%lu bytes, %d padding)\r\n",
               (unsigned long)ss->p.packetLength, (int)ss->p.paddingLength);
        ssh_kill(ss);
        return ss->p.state;
    }

    packetizer_log("packet length %ld padding %d\r\n", 
                   ss->p.packetLength, ss->p.paddingLength);
    // don't consume any decrypted data - it's need for MAC check

    // unhandled data is beginning of payload
    return PACKETIZER_STATE_READING_PACKET_BODY;
}


/*
 * READING_PACKET_BODY handler. Once the rest of the current packet and its
 * MAC are buffered, it:
 *   - decrypts the remainder;
 *   - verifies the MAC, if enabled;
 *   - strips the 5-byte header and the random padding;
 *   - optionally decompresses;
 *   - passes the payload to transport_receive_packet().
 *
 * On entry, ss->p.decryptedData already holds the first inHeaderLength
 * decrypted bytes (appended by packetizer_handle_header). `b`
 * (unhandledData) holds the still-encrypted remainder followed by the MAC.
 *
 * Returns the current state while waiting for data, or after ssh_kill() on
 * a MAC mismatch. Otherwise it returns READING_PACKET_HEADER, unless the
 * transport/connection layers changed ss->p.state while handling the packet.
 */
static packetizer_state_t 
packetizer_handle_body(ssh_session_t *ss, Buffer *b)
{
    Mac *mac = &ss->p.inKeys.mac;
    Comp *comp = ss->p.inKeys.comp;
    void *comp_ctx = ss->p.inKeys.comp_ctx;
    // Owns the decompressed payload when decompression succeeds.
    Buffer comp_buffer = {0, 0, 0, 0};

    // NOTE(review): uint16_t arithmetic; this assumes
    // packetLength + 4 >= inHeaderLength (otherwise it wraps) and that
    // packetLength is small enough to fit.
    uint16_t encryptedLen = ss->p.packetLength + 4 - ss->p.inHeaderLength;
    // +4 for packet_length field, -inHeaderLength for already-decrypted block
    // encryptedLen includes everything still encrypted, but not the MAC

    packetizer_log("reading-packet-body\r\n");

    // NOTE(review): mac->mac_len is counted here even when mac->enabled is
    // false; presumably it is 0 in that case - confirm.
    if (buffer_len(b) < encryptedLen + mac->mac_len) {
        // wait for the rest of the packet, including MAC but not including already-decrypted block
        packetizer_log("no body (total %ld, have %d)", ss->p.packetLength + 4, buffer_len(b) + ss->p.inHeaderLength);
        return ss->p.state;
    } else {
        Buffer *d = &ss->p.decryptedData;
        uint8_t *decrypt_dst = buffer_append_space(d, encryptedLen);
        packetizer_decrypt(ss, decrypt_dst, buffer_ptr(b), encryptedLen);
        buffer_consume(b, encryptedLen);
        
        // Entire decrypted packet is now in decryptedData (d)
        // Entire MAC is now in unhandledData (b)

        // Check MAC
        if (mac->enabled) {
            // Compute MAC(sequence number || decrypted length+payload+padding)
            uint8_t *macbuf = mac_compute(mac, ss->p.seqnoIn, 
                                          buffer_ptr(d), buffer_len(d));
            if (memcmp(macbuf, buffer_ptr(b), mac->mac_len) != 0) {
                // MAC is wrong
                printf("\r\npacketizer: bad MAC from server\r\n");
                ssh_kill(ss);
                arena_free(macbuf);
                return ss->p.state;
            }
            buffer_consume(b, mac->mac_len);
            arena_free(macbuf);
        }

        // unpack and uncompress decrypted payload
        buffer_consume(d, 4+1);                     // leading header
        buffer_consume_end(d, ss->p.paddingLength); // trailing random pad
        if (comp) {
            unsigned char *newData;
            size_t newLen;
            // If transcode fails, the payload is delivered as-is.
            if (comp->transcode(comp_ctx, buffer_ptr(d), buffer_len(d), 
                                &newData, &newLen)) 
            {
                comp_buffer.buf = newData;
                comp_buffer.alloc = newLen;
                comp_buffer.offset = 0;
                comp_buffer.end = newLen;
                buffer_clear(d);                
                d = &comp_buffer;
            }
        }

        // May re-enter the packetizer (e.g. change keys or kill the session).
        transport_receive_packet(ss, d);
        buffer_clear(d);
        if (comp_buffer.buf) buffer_free(&comp_buffer);
        
        ss->p.seqnoIn++;

        // NOTE state may have changed during transport or connection layer !
        if (ss->p.state == PACKETIZER_STATE_READING_PACKET_BODY) {
            return PACKETIZER_STATE_READING_PACKET_HEADER;
        } else {
            return ss->p.state;
        }
    }
}


/*
 * Socket-readable callback: read up to READ_CHUNK bytes into
 * ss->p.unhandledData, then run the packetizer state machine until a
 * handler leaves the state unchanged or the packetizer is closed.
 *
 * A would-block read returns without processing. EOF or a read error kills
 * the session via ssh_kill(). When the packetizer is already CLOSED, no
 * read is attempted.
 */
void packetizer_receive_data(ssh_session_t *ss)
{
    enum { READ_CHUNK = 4096 };  // bytes reserved per read() call
    Buffer *b = &ss->p.unhandledData;
    packetizer_state_t prevState;

    if (ss->p.state != PACKETIZER_STATE_CLOSED) {
        uint8_t *newData;
        int16_t newLen;
        int readErrno;
        
        packetizer_log("packetizer read\r\n");
        newData = buffer_append_space(b, READ_CHUNK);
        newLen = read(ss->s.socket, newData, READ_CHUNK);
        // Capture errno immediately: the calls below may overwrite it.
        readErrno = errno;
        RAND_add_net_entropy();
        packetizer_log("read %d bytes\r\n", newLen);

        if (newLen <= 0) {
            // Nothing was read: give back the whole reservation first.
            buffer_consume_end(b, READ_CHUNK);

            if (newLen == -1  &&  
                (readErrno == EAGAIN  ||  readErrno == EWOULDBLOCK)) 
            {
                // no data available
                return;
            }

            if (newLen == 0) {
                // socket closed
                ssh_kill(ss);
                return;
            }

            printf("packetizer: read failed (return %d, errno %d)\r\n", 
                   newLen, readErrno);
            ssh_kill(ss);
            return;
        }
        
        // keep only the bytes actually read
        buffer_consume_end(b, READ_CHUNK - newLen);
    }


    do {
        prevState = ss->p.state;

        switch (ss->p.state) {
        case PACKETIZER_STATE_WAITING_FOR_VERSION:
            ss->p.state = packetizer_handle_version_start(ss, b);
            break;
            
        case PACKETIZER_STATE_READING_NOT_VERSION:
            ss->p.state = packetizer_handle_not_version(ss, b);
            break;
            
        case PACKETIZER_STATE_READING_VERSION:
            ss->p.state = packetizer_handle_version(ss, b);
            break;
            
        case PACKETIZER_STATE_READING_PACKET_HEADER:
            ss->p.state = packetizer_handle_header(ss, b);
            break;
            
        case PACKETIZER_STATE_READING_PACKET_BODY: 
            ss->p.state = packetizer_handle_body(ss, b);
            break;

        case PACKETIZER_STATE_CLOSED:
            ss->p.state = PACKETIZER_STATE_CLOSED;
            break;

        default:
            // busted!
            printf("packetizer: BUSTED!\r\n");
            ssh_kill(ss);
            break;
        }
    } while (ss->p.state != prevState  &&  
             ss->p.state != PACKETIZER_STATE_CLOSED);
}


/*
 * Release everything the packetizer owns, then zero ss->p and leave it in
 * PACKETIZER_STATE_CLOSED. What is released: both version strings, both
 * data buffers, both key sets and both cipher contexts.
 */
static void packetizer_cleanup(ssh_session_t *ss)
{
    // Version strings exist only once set; free just the ones present.
    if (ss->p.local_version_string != NULL) {
        arena_free(ss->p.local_version_string);
    }
    if (ss->p.remote_version_string != NULL) {
        arena_free(ss->p.remote_version_string);
    }

    buffer_free(&ss->p.unhandledData);
    buffer_free(&ss->p.decryptedData);

    keys_free(&ss->p.inKeys);
    keys_free(&ss->p.outKeys);

    cipher_cleanup(&ss->p.inContext);
    cipher_cleanup(&ss->p.outContext);

    // Wipe every field (including the now-dangling pointers), then mark closed.
    memset(&ss->p, 0, sizeof ss->p);
    ss->p.state = PACKETIZER_STATE_CLOSED;
}


/*
 * Tear the packetizer down immediately. Repeated calls are harmless: once
 * the state is PACKETIZER_STATE_CLOSED, nothing further is done.
 */
void packetizer_kill(ssh_session_t *ss)
{
    if (ss->p.state == PACKETIZER_STATE_CLOSED) {
        return;   // already torn down
    }
    packetizer_cleanup(ss);
}


/*
 * Close the packetizer. This layer has no separate graceful shutdown, so
 * this is exactly packetizer_kill().
 */
void packetizer_close(ssh_session_t *ss)
{
    packetizer_kill(ss);
}


/*
 * Returns 1 once packetizer_start() has run and until the packetizer is
 * closed. Returns 0 while still STARTING or after reaching CLOSED.
 */
int packetizer_is_open(ssh_session_t *ss)
{
    packetizer_state_t st = ss->p.state;

    if (st == PACKETIZER_STATE_STARTING  ||  st == PACKETIZER_STATE_CLOSED) {
        return 0;
    }
    return 1;
}


/*
 * Always returns 0. The packetizer has no intermediate "closing" state:
 * close and kill both go straight to PACKETIZER_STATE_CLOSED.
 */
int packetizer_is_closing(ssh_session_t *ss)
{
    (void)ss;   // unused; kept for interface symmetry with other layers
    return 0;
}

/*
 * Begin an SSHv2 session on ss->s.socket. It:
 *   - initializes the packetizer buffers;
 *   - installs the "none" cipher in both directions, with 8-byte header
 *     blocks;
 *   - records and sends our version line (followed by CR LF);
 *   - moves to WAITING_FOR_VERSION.
 *
 * Does nothing unless the packetizer is in PACKETIZER_STATE_STARTING.
 * If the version line cannot be written in full, the session is killed via
 * ssh_kill() and the state is not advanced; otherwise both sides would wait
 * forever for a version line the server never received.
 */
void packetizer_start(ssh_session_t *ss, char *local_version_string)
{
    Cipher *none;
    const char *eoln = "\r\n";
    size_t versionLen;
    size_t eolnLen;
    size_t len;
    size_t written;
    char *buf;

    if (ss->p.state != PACKETIZER_STATE_STARTING) {
        return;
    }

    printf("Starting SSHv2 session\r\n");
    printf("   Sending version...\r\n");

    none = cipher_by_name("none");

    buffer_init(&ss->p.unhandledData);
    buffer_init(&ss->p.decryptedData);
    
    ss->p.inHeaderLength = 8;
    ss->p.outHeaderLength = 8;
    
    ss->p.local_version_string = arena_strdup(local_version_string);
    
    cipher_init(&ss->p.inContext,  none, "", 0, NULL, 0, CIPHER_DECRYPT);
    cipher_init(&ss->p.outContext, none, "", 0, NULL, 0, CIPHER_ENCRYPT);

    // WARNING WARNING WARNING
    // The client version and line terminator MUST be sent in 
    // a single packet. The SSH server SSH-1.99-Cisco-1.25 will 
    // kill the connection if the version and line terminator 
    // are in separate TCP packets.
    // The code below uses a single write() to send a single packet.
    versionLen = strlen(local_version_string);
    eolnLen = strlen(eoln);
    len = versionLen + eolnLen;
    buf = arena_malloc(len + 1);
    memcpy(buf, local_version_string, versionLen);
    memcpy(buf + versionLen, eoln, eolnLen + 1);   // includes the NUL
    written = write(ss->s.socket, buf, len);
    arena_free(buf);

    // A negative write() result converts to a huge size_t and fails here too.
    if (written != len) {
        ssh_kill(ss);
        return;
    }

    ss->p.state = PACKETIZER_STATE_WAITING_FOR_VERSION;
}
