// homecontroller/netCode.go
// (Repository-viewer header: 232 lines, 5.7 KiB, Go.)

package main
import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"encoding/binary"
	"errors"
	"fmt"
	"hash/crc32"
	"io"
	"net"
	"time"
)
// Device represents either a live client connection or the logical device that
// a connection controls. handleRequest creates one Device per TCP connection;
// a successful connection request then attaches a second Device (via the
// device field) describing the authenticated client's hardware.
type Device struct {
	// cipher is nil until a connection request succeeds; afterwards every
	// packet on conn is sealed/opened with it (see sendPacket/receivePacket).
	cipher *cipher.Block

	// conn is the TCP connection; nil for Devices not bound to a connection.
	conn *net.TCPConn

	// device is the logical device created for this connection on handshake.
	device *Device

	name string
	id   uint32

	// typeID selects the entry in deviceChannelCount that sizes channels.
	typeID uint8

	// channels holds one byte value per channel.
	channels []uint8

	// client points from a logical device back to the connection owning it.
	client *Device

	// lastHeardFrom is the Unix time (seconds) of the last valid packet.
	lastHeardFrom int64

	// privileges is copied from the StoredClient record.
	// NOTE(review): exact semantics not visible here; a non-zero value forces
	// device type 0 during the handshake.
	privileges uint8
}

// StoredClient is the persisted record of a client that may connect.
// receiveConnectionRequest looks clients up by ID and passes Key to
// aes.NewCipher, so Key must be 16, 24 or 32 bytes long.
type StoredClient struct {
	Key        []byte `json:"key"`
	Name       string `json:"name"`
	ID         uint32 `json:"id"`
	Type       uint8  `json:"type"`
	Privileges uint8  `json:"privileges"`
}

// maxPacketLength bounds the length prefix receivePacket accepts, so a
// malformed or hostile peer cannot force an arbitrarily large allocation.
const maxPacketLength = 1 << 20

// sendPacket writes one framed packet carrying packetID and data to client.conn.
//
// Frame layout (integers are little-endian):
//
//	length (uint32) | body
//
// where length counts the bytes of body. In plaintext mode (client.cipher ==
// nil) the body is
//
//	checksum (uint32, CRC-32/IEEE over packetID||data) | packetID | data
//
// Once client.cipher is set, that same plaintext body is sealed with AES-GCM
// under a fresh random nonce and sent as nonce || ciphertext-with-tag.
//
// NOTE(review): the encrypted framing is defined here; peers must implement
// the identical scheme — confirm against the device firmware.
//
// It returns any error from sealing or from the connection write.
func sendPacket(packetID uint8, data []byte, client *Device) error {
	// Reserve the 4 checksum bytes up front and fill them in once the
	// checksummed region (packetID || data) is in place.
	body := make([]byte, 4, 5+len(data))
	body = append(body, packetID)
	body = append(body, data...)
	binary.LittleEndian.PutUint32(body[:4], crc32.ChecksumIEEE(body[4:]))

	if client.cipher != nil {
		sealed, err := sealPayload(*client.cipher, body)
		if err != nil {
			return err
		}
		body = sealed
	}

	// One Write for the whole frame (length prefix + body).
	outBuffer := make([]byte, 4, 4+len(body))
	binary.LittleEndian.PutUint32(outBuffer, uint32(len(body)))
	outBuffer = append(outBuffer, body...)

	writeLen, err := client.conn.Write(outBuffer)
	if err != nil {
		return err
	}
	if writeLen != len(outBuffer) {
		return errors.New("write length does not match output buffer length")
	}
	return nil
}

// sealPayload encrypts and authenticates plaintext with AES-GCM built on block.
// It returns nonce || ciphertext-with-tag.
//
// A bare cipher.Block only transforms exactly one 16-byte block, so it cannot
// be applied to a variable-length packet directly. GCM also detects tampering.
func sealPayload(block cipher.Block, plaintext []byte) ([]byte, error) {
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("creating GCM: %w", err)
	}
	nonce := make([]byte, aead.NonceSize())
	if _, err := rand.Read(nonce); err != nil {
		return nil, fmt.Errorf("generating nonce: %w", err)
	}
	return append(nonce, aead.Seal(nil, nonce, plaintext, nil)...), nil
}

// openPayload reverses sealPayload. It returns an error if sealed is too short
// or fails authentication (wrong key or modified bytes).
func openPayload(block cipher.Block, sealed []byte) ([]byte, error) {
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("creating GCM: %w", err)
	}
	nonceSize := aead.NonceSize()
	if len(sealed) < nonceSize+aead.Overhead() {
		return nil, errors.New("encrypted packet too short")
	}
	plaintext, err := aead.Open(nil, sealed[:nonceSize], sealed[nonceSize:], nil)
	if err != nil {
		return nil, fmt.Errorf("decrypting packet: %w", err)
	}
	return plaintext, nil
}

// sendPacket is the method form of the package-level sendPacket.
func (client *Device) sendPacket(packetID uint8, data []byte) error {
	return sendPacket(packetID, data, client)
}

// receivePacket is the method form of the package-level receivePacket.
func (client *Device) receivePacket() (uint8, uint32, []byte, error) {
	return receivePacket(client)
}

// updateLastSeen stamps client.lastHeardFrom with the current Unix time, but
// only for Devices that are bound to a connection.
func updateLastSeen(client *Device) {
	if client.conn != nil {
		client.lastHeardFrom = time.Now().Unix()
	}
}

// receivePacket reads one frame (layout documented on sendPacket) from
// client.conn. It opens the body with client.cipher when one is set and
// verifies the CRC-32 checksum.
//
// It returns the packet ID, the data length, and the data. On success it also
// refreshes client.lastHeardFrom. It returns an error on I/O failure (io.EOF
// when the peer closes between frames), on an out-of-range length, on a
// decryption/authentication failure, or on a checksum mismatch.
func receivePacket(client *Device) (uint8, uint32, []byte, error) {
	// io.ReadFull is required: a single Read on a TCP stream may legitimately
	// return fewer bytes than requested.
	var lengthBuffer [4]byte
	if _, err := io.ReadFull(client.conn, lengthBuffer[:]); err != nil {
		return 0, 0, nil, err
	}
	packetLength := binary.LittleEndian.Uint32(lengthBuffer[:])
	if packetLength > maxPacketLength {
		return 0, 0, nil, fmt.Errorf("packet length %d exceeds limit %d", packetLength, maxPacketLength)
	}

	body := make([]byte, packetLength)
	if _, err := io.ReadFull(client.conn, body); err != nil {
		return 0, 0, nil, err
	}

	if client.cipher != nil {
		plaintext, err := openPayload(*client.cipher, body)
		if err != nil {
			return 0, 0, nil, err
		}
		body = plaintext
	}

	// checksum (4) + packetID (1) is the minimum body; indexing a shorter
	// buffer below would panic.
	if len(body) < 5 {
		return 0, 0, nil, errors.New("packet too short")
	}
	receivedChecksum := binary.LittleEndian.Uint32(body[:4])
	if crc32.ChecksumIEEE(body[4:]) != receivedChecksum {
		return 0, 0, nil, errors.New("checksum mismatch")
	}

	packetID := body[4]
	data := body[5:]
	updateLastSeen(client)
	return packetID, uint32(len(data)), data, nil
}
// receiveConnectionRequest handles serverbound packet 0. The first 4 bytes of
// data are the client ID (little-endian). If a matching StoredClient exists,
// the method sets the following on client:
//   - name and id;
//   - a new logical Device (type, channel slice sized by deviceChannelCount,
//     privileges), linked back to client;
//   - an AES cipher built from the stored key.
//
// Every later packet on the connection is then encrypted. Short payloads and
// unknown IDs are ignored. If the stored key is unusable, client is left
// untouched.
func (client *Device) receiveConnectionRequest(data []byte) {
	// The original guard was inverted and indexed data[0:4] exactly when
	// fewer than 4 bytes were present.
	if len(data) < 4 {
		return
	}
	clientID := binary.LittleEndian.Uint32(data[0:4])
	for _, stored := range storedClients {
		if stored.ID != clientID {
			continue
		}
		// Build the cipher before mutating client so a bad key cannot leave a
		// half-initialized (named but unencrypted) connection behind.
		clientCipher, err := aes.NewCipher(stored.Key)
		if err != nil {
			fmt.Println("Error creating AES cipher:", err)
			return
		}
		// A local copy avoids writing to the stored record (which would be
		// visible to later lookups if storedClients holds pointers).
		deviceType := stored.Type
		if stored.Privileges > 0 {
			deviceType = 0
		}
		client.name = stored.Name
		client.id = clientID
		client.device = &Device{
			typeID:     deviceType,
			channels:   make([]uint8, deviceChannelCount[deviceType]),
			client:     client,
			privileges: stored.Privileges,
		}
		client.cipher = &clientCipher
		return
	}
}
// receiveChannelUpdate handles serverbound packet 2. data[0] is a channel
// index and data[1] the new value. The value is stored in client.channels when
// the index is in range. Payloads shorter than 2 bytes and out-of-range
// indices are ignored.
//
// NOTE(review): handleRequest calls this on the per-connection Device, whose
// channels slice is never populated in the code visible here. The sized slice
// is created on client.device during the handshake. Confirm which Device is
// meant to receive channel updates.
func (client *Device) receiveChannelUpdate(data []byte) {
	// The original guard was inverted (`< 2`) and read data[1] on short input.
	if len(data) < 2 {
		return
	}
	channelID, channelValue := data[0], data[1]
	// Compare as int: uint8(len(...)) wraps once there are 256+ channels.
	if int(channelID) >= len(client.channels) {
		return
	}
	client.channels[channelID] = channelValue
}
// handleRequest serves a single TCP connection until it ends:
//  1. It registers a fresh Device for conn in the package-level clients list.
//  2. It reads packets in a loop and dispatches the ones it understands.
//  3. It stops on any receive error, or on a second connection request after
//     encryption is established.
//
// conn is closed when the function returns.
//
// Serverbound packets:
//
//	0: start connection
//	1: reserved
//	2: channel value
//	3: list devices  (NOTE(review): same description as 6 — confirm)
//	4: set device channel
//	5: get device channel
//	6: list devices
//	7: list admins
//	8: add device
//	9: add controller
//
// Clientbound packets:
//
//	2: set channel
//	3: get channel
//	4: channel from device
//	5: added device key and id
//	6: device entry
//	7: admin entry
//
// NOTE(review): the Device appended to clients is never removed here when the
// connection ends; confirm whether other code prunes closed connections.
func handleRequest(conn *net.TCPConn) {
	session := &Device{conn: conn}
	clients = append(clients, session)
	defer func() {
		_ = conn.Close()
	}()

	for {
		id, _, payload, err := receivePacket(session)
		if err != nil {
			return
		}
		// A non-nil cipher means the connection request already succeeded.
		encrypted := session.cipher != nil
		switch {
		case id == 0 && encrypted:
			// A repeated connection request terminates the session.
			return
		case id == 0:
			session.receiveConnectionRequest(payload)
		case id == 2 && encrypted:
			session.receiveChannelUpdate(payload)
		}
		// Packet 1 (reserved), unauthenticated channel updates, and all other
		// IDs are ignored.
	}
}