232 lines
5.7 KiB
Go
232 lines
5.7 KiB
Go
package main
|
|
|
|
import (
	"crypto/aes"
	"crypto/cipher"
	"encoding/binary"
	"errors"
	"fmt"
	"hash/crc32"
	"io"
	"net"
	"time"
)
|
|
|
|
// Device holds server-side state for either a TCP connection or the device
// record attached to that connection.
//
// handleRequest creates one Device per accepted connection (conn set, cipher
// nil). When a connection request is accepted, receiveConnectionRequest
// attaches a second Device through the device field. That second Device
// carries the type, channel values and privileges, and points back to the
// connection through its client field.
type Device struct {
	// cipher transforms packet bodies in sendPacket/receivePacket. It is nil
	// until a connection request has been accepted.
	// NOTE(review): a pointer to an interface is unusual; a plain cipher.Block
	// would do, but other code dereferences this field.
	cipher *cipher.Block

	// conn is the TCP connection packets are read from and written to.
	conn *net.TCPConn

	// device is the companion record created on a successful connection
	// request; nil before that.
	device *Device

	// name is copied from the matching StoredClient.
	name string

	// id is the client ID announced in the connection request.
	id uint32

	// typeID is the device type. The number of channels for the type is
	// looked up in deviceChannelCount.
	typeID uint8

	// channels holds one value per channel.
	channels []uint8

	// client points from a companion device record back to its connection.
	client *Device

	// lastHeardFrom is the Unix time, in seconds, at which the last packet
	// passed its checksum. Only updated when conn is non-nil.
	lastHeardFrom int64

	// privileges is copied from StoredClient.Privileges.
	privileges uint8
}
|
|
|
|
// StoredClient is a known client entry, as decoded from JSON (see the struct
// tags). Connection requests are matched against it by ID.
type StoredClient struct {
	// Key is the AES key handed to aes.NewCipher for this client's session.
	Key []byte `json:"key"`

	// Name becomes the connection's display name.
	Name string `json:"name"`

	// ID is the identifier a client sends in its connection request.
	ID uint32 `json:"id"`

	// Type is the device type. It is replaced by 0 when Privileges > 0.
	Type uint8 `json:"type"`

	// Privileges is copied onto the connection's device record.
	Privileges uint8 `json:"privileges"`
}
|
|
|
|
// sendPacket sends a packet with the given packetID and data over the provided TCP connection.
|
|
func sendPacket(packetID uint8, data []byte, client *Device) error {
|
|
// Create a buffer with the packet ID and data
|
|
tempBuffer := append([]byte{packetID}, data...)
|
|
|
|
// Calculate the checksum of the packet ID + data
|
|
checksum := crc32.ChecksumIEEE(tempBuffer)
|
|
|
|
// Convert the checksum to a 4-byte slice
|
|
checksumBytes := make([]byte, 4)
|
|
binary.LittleEndian.PutUint32(checksumBytes, checksum)
|
|
|
|
var encryptedTempBuffer []byte
|
|
|
|
if client.cipher != nil {
|
|
(*client.cipher).Encrypt(encryptedTempBuffer, tempBuffer)
|
|
} else {
|
|
encryptedTempBuffer = tempBuffer
|
|
}
|
|
// Calculate the total length (checksum + packet ID + data)
|
|
packetLength := uint32(len(encryptedTempBuffer) + 4)
|
|
|
|
// Prepare the output buffer with the length field
|
|
outBuffer := make([]byte, 4)
|
|
|
|
binary.LittleEndian.PutUint32(outBuffer, packetLength)
|
|
// Append the tempBuffer (checksum + packet ID + data) to outBuffer
|
|
outBuffer = append(outBuffer, checksumBytes...)
|
|
outBuffer = append(outBuffer, tempBuffer...)
|
|
|
|
// Send the packet over the connection
|
|
writeLen, err := client.conn.Write(outBuffer)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if writeLen != len(outBuffer) {
|
|
return errors.New("write length does not match output buffer length")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (client *Device) sendPacket(packetID uint8, data []byte) error {
|
|
return sendPacket(packetID, data, client)
|
|
}
|
|
|
|
func (client *Device) receivePacket() (uint8, uint32, []byte, error) {
|
|
return receivePacket(client)
|
|
}
|
|
|
|
func updateLastSeen(client *Device) {
|
|
if client.conn != nil {
|
|
client.lastHeardFrom = time.Now().Unix()
|
|
}
|
|
}
|
|
|
|
// receivePacket receives a packet from the provided TCP connection, decrypts it, and verifies it.
|
|
func receivePacket(client *Device) (uint8, uint32, []byte, error) {
|
|
// Step 1: Read the packet length (4 bytes)
|
|
lengthBuffer := make([]byte, 4)
|
|
n, err := client.conn.Read(lengthBuffer)
|
|
if err != nil {
|
|
return 0, 0, nil, err
|
|
}
|
|
if n != len(lengthBuffer) {
|
|
return 0, 0, nil, errors.New("read length does not match length buffer")
|
|
}
|
|
|
|
packetLength := binary.LittleEndian.Uint32(lengthBuffer)
|
|
|
|
// Step 2: Allocate a buffer to hold the rest of the packet (checksum + packetID + data)
|
|
packetBuffer := make([]byte, packetLength)
|
|
n, err = client.conn.Read(packetBuffer)
|
|
if err != nil {
|
|
return 0, 0, nil, err
|
|
}
|
|
if n != int(packetLength) {
|
|
return 0, 0, nil, errors.New("read length does not match packet length")
|
|
}
|
|
|
|
// Step 3: Decrypt the packet buffer
|
|
decryptedPacketBuffer := make([]byte, len(packetBuffer))
|
|
if client.cipher != nil {
|
|
(*client.cipher).Decrypt(decryptedPacketBuffer, packetBuffer)
|
|
} else {
|
|
decryptedPacketBuffer = packetBuffer
|
|
}
|
|
|
|
// Step 4: Extract the checksum (4 bytes)
|
|
receivedChecksum := binary.LittleEndian.Uint32(decryptedPacketBuffer[:4])
|
|
|
|
// Step 5: Extract the packet ID (1 byte)
|
|
packetID := decryptedPacketBuffer[4]
|
|
|
|
// Step 6: Extract the data
|
|
data := decryptedPacketBuffer[5:]
|
|
|
|
// Step 7: Verify the checksum
|
|
calculatedChecksum := crc32.ChecksumIEEE(append([]byte{packetID}, data...))
|
|
if receivedChecksum != calculatedChecksum {
|
|
return 0, 0, nil, errors.New("checksum mismatch")
|
|
}
|
|
|
|
// Return the packet ID and data
|
|
updateLastSeen(client)
|
|
return packetID, uint32(len(data)), data, nil
|
|
}
|
|
|
|
func (client *Device) receiveConnectionRequest(data []byte) {
|
|
if len(data) < 4 {
|
|
clientID := binary.LittleEndian.Uint32(data[0:4])
|
|
for _, clientLoop := range storedClients {
|
|
if clientLoop.ID == clientID {
|
|
client.name = clientLoop.Name
|
|
client.id = clientID
|
|
if clientLoop.Privileges > 0 {
|
|
clientLoop.Type = 0
|
|
}
|
|
client.device = &Device{
|
|
typeID: clientLoop.Type,
|
|
channels: make([]uint8, deviceChannelCount[clientLoop.Type]),
|
|
client: client,
|
|
privileges: clientLoop.Privileges,
|
|
}
|
|
clientCipher, err := aes.NewCipher(clientLoop.Key)
|
|
if err != nil {
|
|
fmt.Println("Error creating AES cipher")
|
|
return
|
|
}
|
|
client.cipher = &clientCipher
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (client *Device) receiveChannelUpdate(data []byte) {
|
|
if len(data) < 2 && len(client.channels) != 0 {
|
|
channelID := data[0]
|
|
channelValue := data[1]
|
|
if channelID >= uint8(len(client.channels)) {
|
|
return
|
|
}
|
|
client.channels[channelID] = channelValue
|
|
}
|
|
}
|
|
|
|
func handleRequest(conn *net.TCPConn) {
|
|
//serverbound packets:
|
|
//0: start connection
|
|
//1: reserved
|
|
//2: channel value
|
|
//3: list devices
|
|
//4: set device channel
|
|
//5: get device channel
|
|
//6: list devices
|
|
//7: list admins
|
|
//8: add device
|
|
//9: add controller
|
|
|
|
//clientbound packets:
|
|
//2: set channel
|
|
//3: get channel
|
|
//4: channel from device
|
|
//5: added device key and id
|
|
//6: device entry
|
|
//7: admin entry
|
|
|
|
client := Device{
|
|
cipher: nil,
|
|
conn: conn,
|
|
}
|
|
clients = append(clients, &client)
|
|
defer func(conn *net.TCPConn) {
|
|
_ = conn.Close()
|
|
}(conn)
|
|
for {
|
|
packetID, _, packetData, err := receivePacket(&client)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
switch packetID {
|
|
case 0:
|
|
if client.cipher == nil {
|
|
client.receiveConnectionRequest(packetData)
|
|
} else {
|
|
return
|
|
}
|
|
case 1:
|
|
|
|
case 2:
|
|
if client.cipher != nil {
|
|
client.receiveChannelUpdate(packetData)
|
|
}
|
|
|
|
}
|
|
}
|
|
}
|