/**********
 * Copyright (c) 2003-2004 Greg Parker.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY GREG PARKER ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 **********/

#include "includes.h"
#include "recordlist.h"
#include "rsrc/rsrc.h"
#include "ssh/openssh/key.h"
#include "ssh/openssh/match.h"

#include "hostkeys.h"


/* host key record format:

   2 bytes length (uint16_t, stored exactly as it sits in memory when
           written; counts the terminating \0 of the hostname string)
   n bytes hostname,hostname,hostname\0
   2 bytes length (uint16_t, same encoding)
   n bytes key blob

   Any bytes following the key blob are ignored when a record is read.
*/


// Name and type code handed to OpenDB() when opening the host key database.
#define HostKeyDBName "pssh Known Host Keys"
#define HostKeyDBType 'HKey'
// Open database reference; 0 until HostKeysInit() has opened it.
static DmOpenRef HostKeyDB = 0;
// RecordList bound to the host keys form; NULL until HostKeysInit().
static RecordList *HostKeyList = NULL;

// Record parsing, serialization and table drawing helpers (defined below).
static Boolean ReadHostKeyRecord(uint8_t *recordP, char **hostnames, uint8_t **keyblob, uint16_t *keybloblen) HOSTKEYS_SEGMENT;
static Boolean WriteHostKeyRecord(MemPtr recordP, const char *hostnames, uint8_t *keyblob, uint16_t keybloblen) HOSTKEYS_SEGMENT;
static void DrawHostKeyRecord(MemPtr recordP, UInt16 index, RectanglePtr bounds) HOSTKEYS_SEGMENT;

// Defined elsewhere in the project.
extern DmOpenRef OpenDB(UInt32 type, char *name, Boolean resDB, Boolean create);

/*
 * Open the host key database and create the RecordList used to display
 * it in the host keys form.
 *
 * Returns true when both the database and the list were obtained,
 * false otherwise.
 */
Boolean HostKeysInit(void)
{
    HostKeyDB = OpenDB(HostKeyDBType, HostKeyDBName, false, true);

    if (HostKeyDB) {
        HostKeyList = RecordListNew(HostKeyDB, 
                                    HostKeysFormID, 
                                    HostKeysFormKeyTableID, 
                                    HostKeysFormKeyScrollbarID, 
                                    DrawHostKeyRecord);
        if (HostKeyList) {
            return true;
        }
    }

    return false;
}


void HostKeysFree(void)
{
    RecordListFree(HostKeyList);
    DmCloseDatabase(HostKeyDB);
}


/*
 * Forward an update request to the host key RecordList.
 */
void HostKeysUpdate(void)
{
    // No `return <expr>;` here: ISO C does not allow a return statement
    // with an expression in a function returning void.
    RecordListUpdate(HostKeyList);
}


/*
 * Pass an event to the host key RecordList.
 *
 * Returns whatever RecordListHandleEvent() reports for the event.
 */
Boolean HostKeysHandleEvent(EventPtr event)
{
    Boolean handled = RecordListHandleEvent(HostKeyList, event);

    return handled;
}


/*
 * Return the index that the host key RecordList reports as selected.
 */
UInt16 HostKeysSelectedIndex(void)
{
    UInt16 selected = RecordListSelectedIndex(HostKeyList);

    return selected;
}


/*
 * Ask the host key RecordList to delete its currently selected record.
 */
void HostKeysDeleteSelectedRecord(void)
{
    // No `return <expr>;` here: ISO C does not allow a return statement
    // with an expression in a function returning void.
    RecordListDeleteSelectedRecord(HostKeyList);
}


/*
 * Parse a host key record in place.
 *
 * recordP     start of the record; its size is taken from MemPtrSize().
 * hostnames   receives a pointer to the NUL-terminated hostname list
 *             inside the record.
 * keyblob     receives a pointer to the key blob inside the record.
 * keybloblen  receives the key blob length in bytes.
 *
 * Returns true if the record is well formed. The output parameters are
 * written only in that case and point into the record itself, so they are
 * valid only while the record stays where it is.
 * Returns false if a field runs past the end of the record, or if the
 * hostname field is empty or not NUL-terminated.
 */
static Boolean ReadHostKeyRecord(uint8_t *recordP, char **hostnames, 
                                 uint8_t **keyblob, uint16_t *keybloblen)
{
    uint8_t *p = recordP;
    uint32_t remaining = MemPtrSize(recordP);
    uint16_t hostslen;
    uint16_t bloblen;
    char *hosts;

    // read hostnames string

    if (remaining < 2) return false;
    // The length fields are copied out with memcpy because the second one
    // sits at offset 2+hostslen, which is not always 2-byte aligned.
    memcpy(&hostslen, p, 2);
    p += 2;
    remaining -= 2;

    // A valid string holds at least its terminator; a zero length would
    // otherwise make the terminator check below look at hosts[-1].
    if (hostslen == 0  ||  hostslen > remaining) return false;
    hosts = (char *)p;
    if (hosts[hostslen - 1] != '\0') return false;
    p += hostslen;
    remaining -= hostslen;

    // read key blob

    if (remaining < 2) return false;
    memcpy(&bloblen, p, 2);
    p += 2;
    remaining -= 2;

    if (bloblen > remaining) return false;

    // allow trailing data for forward compatibility

    *hostnames = hosts;
    *keyblob = p;
    *keybloblen = bloblen;
    return true;
}


/*
 * Serialize a host key record into recordP using DmWrite().
 *
 * Layout written, starting at offset 0: 2-byte hostname length (including
 * the terminating NUL), the hostname string with its NUL, 2-byte key blob
 * length, then the key blob bytes.
 *
 * The caller must have sized the record to hold all four fields.
 * Returns true if every DmWrite() succeeded; writing stops at the first
 * failure.
 */
static Boolean WriteHostKeyRecord(MemPtr recordP, const char *hostnames, 
                                  uint8_t *keyblob, uint16_t keybloblen)
{
    uint16_t namelen = strlen(hostnames) + 1;
    uint32_t pos = 0;

    if (DmWrite(recordP, pos, &namelen, sizeof(namelen)) != 0) {
        return false;
    }
    pos += sizeof(namelen);

    if (DmWrite(recordP, pos, hostnames, namelen) != 0) {
        return false;
    }
    pos += namelen;

    if (DmWrite(recordP, pos, &keybloblen, sizeof(keybloblen)) != 0) {
        return false;
    }
    pos += sizeof(keybloblen);

    if (DmWrite(recordP, pos, keyblob, keybloblen) != 0) {
        return false;
    }

    return true;
}



MemHandle HostKeysQuerySelectedRecord(char **hostnames, 
                                      uint8_t **keyblob, uint16_t *keybloblen)
{
    return HostKeysQueryIndexedRecord(RecordListSelectedIndex(HostKeyList), 
                                      hostnames, keyblob, keybloblen);
}


/*
 * Look up the record at `index` and parse it.
 *
 * On success the record handle is returned still locked, and *hostnames,
 * *keyblob and *keybloblen describe data inside the locked record; the
 * caller is expected to MemHandleUnlock() the handle when finished.
 *
 * Returns NULL (with nothing left locked) if the record cannot be
 * obtained or does not parse.
 */
MemHandle HostKeysQueryIndexedRecord(UInt16 index, char **hostnames, 
                                     uint8_t **keyblob, uint16_t *keybloblen)
{
    MemHandle handle = RecordListQueryIndexedRecord(HostKeyList, index);

    if (handle) {
        MemPtr data = MemHandleLock(handle);

        if (ReadHostKeyRecord(data, hostnames, keyblob, keybloblen)) {
            // keep the lock: the outputs point into the record
            return handle;
        }

        MemHandleUnlock(handle);
    }

    return NULL;
}


/*
 * Find the first record whose hostname list matches `hostname`
 * according to match_list().
 *
 * Returns the record index, or noRecord if no readable record matches.
 * Records that cannot be read or parsed are skipped.
 */
UInt16 HostKeysFindRecordForHostname(const char *hostname)
{
    UInt16 total = RecordListCount(HostKeyList);
    UInt16 i;

    for (i = 0; i < total; i++) {
        char *hostlist;
        uint8_t *blob;
        uint16_t bloblen;
        char *hit;
        MemHandle handle;

        handle = HostKeysQueryIndexedRecord(i, &hostlist, &blob, &bloblen);
        if (!handle) continue;

        hit = match_list(hostname, hostlist, NULL);
        MemHandleUnlock(handle);
        if (!hit) continue;

        // the match result is only needed as a yes/no answer
        xfree(hit);
        return i;
    }

    return noRecord;
}


/*
 * Find the first record whose stored key equals `hostkey`.
 *
 * Returns the record index, or noRecord if no record holds an equal key.
 * Records that cannot be read, and records whose key blob cannot be
 * decoded by key_from_blob(), are skipped.
 */
UInt16 HostKeysFindRecordForKey(Key *hostkey)
{
    UInt16 count = RecordListCount(HostKeyList);
    UInt16 index;
    MemHandle recordH;
    char *savedhosts;
    uint8_t *keyblob;
    uint16_t keybloblen;

    for (index = 0; index < count; index++) {
        Key *savedkey;
        Boolean match = false;

        recordH = HostKeysQueryIndexedRecord(index, &savedhosts, 
                                             &keyblob, &keybloblen);
        if (!recordH) continue;

        // A stored blob that fails to decode yields no key; it must not be
        // handed to key_free(), so treat it as a non-match.
        savedkey = key_from_blob(keyblob, keybloblen);
        if (savedkey) {
            match = key_equal(savedkey, hostkey);
            key_free(savedkey);
        }
        MemHandleUnlock(recordH);

        if (match) return index;
    }

    return noRecord;
}


/*
 * Append a new record holding `hostname` and the serialized `hostkey`.
 *
 * Returns true if the record was created and written. Returns false if
 * the key could not be serialized, the record could not be obtained, or
 * writing failed. After a write failure the new record is still released
 * to the RecordList (see fixme below).
 */
Boolean HostKeysAddRecord(const char *hostname, Key *hostkey)
{
    // NOTE(review): the initial value 1 is kept from the original code;
    // key_to_blob() is expected to overwrite it - confirm.
    uint16_t keybloblen = 1;
    uint32_t recordlen;
    MemHandle recordH;
    MemPtr recordP;
    uint8_t *keyblob = NULL;
    Boolean ok;

    // keyblob starts out NULL so that a key_to_blob() failure that leaves
    // it untouched is detected here rather than writing and freeing garbage.
    key_to_blob(hostkey, &keyblob, &keybloblen);
    if (!keyblob) return false;

    // length field + hostname + NUL + length field + key blob
    recordlen = 2L + strlen(hostname) + 1 + 2 + keybloblen;
    RecordListClearSelection(HostKeyList);
    recordH = RecordListGetSelectedRecord(HostKeyList, recordlen);
    if (!recordH) {
        xfree(keyblob);
        return false;
    }

    recordP = MemHandleLock(recordH);

    ok = WriteHostKeyRecord(recordP, hostname, keyblob, keybloblen);

    // fixme delete new record on failure
    xfree(keyblob);
    MemHandleUnlock(recordH);
    RecordListReleaseRecord(HostKeyList, recordH, true);

    return ok;
}


/*
 * Append ",hostname" to the hostname list of the record at `index`.
 *
 * The record is first grown by strlen(hostname)+1 bytes, then parsed and
 * rewritten with the extended list and the same key blob. This assumes
 * the record already holds at least one hostname and that `hostname` is
 * not already in the list.
 *
 * Returns true if the record was rewritten; false if the record could not
 * be obtained, resized, parsed or written.
 */
Boolean HostKeysAddHostnameToRecord(const char *hostname, UInt16 index)
{
    MemHandle handle;
    MemPtr data;
    uint32_t grownsize;
    char *oldhosts;
    uint8_t *oldblob;
    uint16_t bloblen;
    Boolean result = false;

    RecordListClearSelection(HostKeyList);
    RecordListSetSelectedIndex(HostKeyList, index);

    handle = RecordListQuerySelectedRecord(HostKeyList);
    if (!handle) return false;

    // room for the extra comma and the new hostname
    grownsize = MemHandleSize(handle) + strlen(hostname) + strlen(",");

    handle = DmResizeRecord(HostKeyDB, index, grownsize);
    if (!handle) return false;
    data = MemHandleLock(handle);

    if (ReadHostKeyRecord(data, &oldhosts, &oldblob, &bloblen)) {
        char *merged;
        uint8_t *blobcopy;

        // oldhosts and oldblob point into the record that is about to be
        // overwritten, so both are copied out before writing.
        merged = arena_malloc(strlen(oldhosts) + strlen(",") + 
                              strlen(hostname) + 1);
        strcpy(merged, oldhosts);
        strcat(merged, ",");
        strcat(merged, hostname);

        blobcopy = arena_malloc(bloblen);
        memcpy(blobcopy, oldblob, bloblen);

        result = WriteHostKeyRecord(data, merged, blobcopy, bloblen);

        arena_free(merged);
        arena_free(blobcopy);
    }

    // fixme destroy record on failure
    MemHandleUnlock(handle);
    RecordListReleaseRecord(HostKeyList, handle, true);
    return result;
}



/*
 * Remove `hostname` from the hostname list of the record at `index`.
 *
 * If `hostname` is the record's entire list, the record itself is
 * deleted. Otherwise the first list entry equal to `hostname` (compared
 * with strcasecmp) is removed together with one adjacent comma, and the
 * record is shrunk and rewritten with the same key blob.
 *
 * Returns true if the hostname was removed (or the record deleted);
 * false if the record could not be read, the hostname is not in the list,
 * or resizing/writing failed.
 */
Boolean HostKeysRemoveHostnameFromRecord(const char *hostname, UInt16 index)
{
    MemHandle recordH;
    MemPtr recordP;
    uint32_t recordlen;
    char *hostnames;
    uint8_t *keyblob;
    uint16_t keybloblen;
    Boolean ok;
    char *newhostnames;
    uint8_t *newkeyblob;
    char *start, *end;
    char *entry;

    RecordListClearSelection(HostKeyList);
    RecordListSetSelectedIndex(HostKeyList, index);

    recordH = RecordListQuerySelectedRecord(HostKeyList);
    if (!recordH) return false;
    recordP = MemHandleLock(recordH);

    ok = ReadHostKeyRecord(recordP, &hostnames, &keyblob, &keybloblen);
    if (!ok) {
        MemHandleUnlock(recordH);
        return false;
    }

    // If this is the only hostname in the record - kill it completely
    if (0 == strcasecmp(hostnames, hostname)) {
        MemHandleUnlock(recordH);
        RecordListDeleteSelectedRecord(HostKeyList);
        return true;
    }

    // Work on a private copy of the list and walk it entry by entry.
    // Each entry is temporarily NUL-terminated so that it can be compared
    // as a whole, with the same case-insensitive comparison for every
    // position in the list.
    newhostnames = arena_strdup(hostnames);
    start = end = NULL;
    entry = newhostnames;
    while (entry) {
        char *stop = entry;
        char saved;
        int differs;

        while (*stop != '\0'  &&  *stop != ',') stop++;
        saved = *stop;
        *stop = '\0';
        differs = strcasecmp(entry, hostname);
        *stop = saved;

        if (!differs) {
            if (saved == ',') {
                // first or middle entry: drop "hostname," so that the
                // neighbours stay separated by exactly one comma
                start = entry;
                end = stop + 1;
            } else if (entry > newhostnames) {
                // last entry: drop ",hostname"
                start = entry - 1;
                end = stop;
            }
            // (a sole entry was already handled by the check above)
            break;
        }

        entry = (saved == ',') ? stop + 1 : NULL;
    }
    if (!start) {
        // didn't find hostname in hostname list
        MemHandleUnlock(recordH);
        arena_free(newhostnames);
        return false;
    }

    // kill [start..end)
    memmove(start, end, strlen(end)+1);
    
    // keyblob points into the record that is about to be resized and
    // rewritten, so copy it out first.
    newkeyblob = arena_malloc(keybloblen);
    memcpy(newkeyblob, keyblob, keybloblen);
    
    recordlen = 2L + strlen(newhostnames) + 1 + 2 + keybloblen;
    
    // Reopen the record for writing and resize
    MemHandleUnlock(recordH);
    recordH = DmResizeRecord(HostKeyDB, RecordListSelectedIndex(HostKeyList), 
                             recordlen);
    if (!recordH) {
        arena_free(newhostnames);
        arena_free(newkeyblob);
        return false;
    }

    // Write the new record data
    recordP = MemHandleLock(recordH);
    ok = WriteHostKeyRecord(recordP, newhostnames, newkeyblob, keybloblen);
    MemHandleUnlock(recordH);
    RecordListReleaseRecord(HostKeyList, recordH, true);
    arena_free(newhostnames);
    arena_free(newkeyblob);
    if (!ok) return false;

    return true;
}


/*
 * Draw callback handed to RecordListNew(): draws a record's hostname list
 * inside `bounds`.
 *
 * The text starts one pixel in from the left edge of `bounds` and is
 * passed to WinDrawTruncChars() with a maximum width that leaves one
 * pixel free at the right edge. Nothing is drawn if the record does not
 * parse. `index` is not used.
 */
static void DrawHostKeyRecord(MemPtr recordP, UInt16 index, 
                              RectanglePtr bounds)
{
    char *hostnames;
    uint8_t *keyblob;
    uint16_t keybloblen;
    int x, y, maxwidth;

    if (!ReadHostKeyRecord(recordP, &hostnames, &keyblob, &keybloblen)) {
        return;
    }

    // "hostname,hostname,hostname"
    x = bounds->topLeft.x + 1;
    y = bounds->topLeft.y;
    maxwidth = bounds->topLeft.x + bounds->extent.x - x - 1;

    // The drawn width was previously measured with FntCharsWidth() but the
    // result was never used, so the measurement has been dropped.
    WinDrawTruncChars(hostnames, StrLen(hostnames), x, y, maxwidth);
}
