619 lines
17 KiB
Go
Raw Permalink Normal View History

2021-12-04 16:42:11 +00:00
package kcp
import (
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/sha1"
"unsafe"
xor "github.com/templexxx/xorsimd"
"github.com/tjfoc/gmsm/sm4"
"golang.org/x/crypto/blowfish"
"golang.org/x/crypto/cast5"
"golang.org/x/crypto/pbkdf2"
"golang.org/x/crypto/salsa20"
"golang.org/x/crypto/tea"
"golang.org/x/crypto/twofish"
"golang.org/x/crypto/xtea"
)
var (
initialVector = []byte{167, 115, 79, 156, 18, 172, 27, 1, 164, 21, 242, 193, 252, 120, 230, 107}
saltxor = `sH3CIVoF#rWLtJo6`
)
// BlockCrypt defines encryption/decryption methods for a given byte slice.
// Notes on implementing: the data to be encrypted contains a builtin
// nonce at the first 16 bytes
type BlockCrypt interface {
// Encrypt encrypts the whole block in src into dst.
// Dst and src may point at the same memory.
Encrypt(dst, src []byte)
// Decrypt decrypts the whole block in src into dst.
// Dst and src may point at the same memory.
Decrypt(dst, src []byte)
}
type salsa20BlockCrypt struct {
key [32]byte
}
// NewSalsa20BlockCrypt https://en.wikipedia.org/wiki/Salsa20
func NewSalsa20BlockCrypt(key []byte) (BlockCrypt, error) {
c := new(salsa20BlockCrypt)
copy(c.key[:], key)
return c, nil
}
func (c *salsa20BlockCrypt) Encrypt(dst, src []byte) {
salsa20.XORKeyStream(dst[8:], src[8:], src[:8], &c.key)
copy(dst[:8], src[:8])
}
func (c *salsa20BlockCrypt) Decrypt(dst, src []byte) {
salsa20.XORKeyStream(dst[8:], src[8:], src[:8], &c.key)
copy(dst[:8], src[:8])
}
type sm4BlockCrypt struct {
encbuf [sm4.BlockSize]byte // 64bit alignment enc/dec buffer
decbuf [2 * sm4.BlockSize]byte
block cipher.Block
}
// NewSM4BlockCrypt https://github.com/tjfoc/gmsm/tree/master/sm4
func NewSM4BlockCrypt(key []byte) (BlockCrypt, error) {
c := new(sm4BlockCrypt)
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *sm4BlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *sm4BlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type twofishBlockCrypt struct {
encbuf [twofish.BlockSize]byte
decbuf [2 * twofish.BlockSize]byte
block cipher.Block
}
// NewTwofishBlockCrypt https://en.wikipedia.org/wiki/Twofish
func NewTwofishBlockCrypt(key []byte) (BlockCrypt, error) {
c := new(twofishBlockCrypt)
block, err := twofish.NewCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *twofishBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *twofishBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type tripleDESBlockCrypt struct {
encbuf [des.BlockSize]byte
decbuf [2 * des.BlockSize]byte
block cipher.Block
}
// NewTripleDESBlockCrypt https://en.wikipedia.org/wiki/Triple_DES
func NewTripleDESBlockCrypt(key []byte) (BlockCrypt, error) {
c := new(tripleDESBlockCrypt)
block, err := des.NewTripleDESCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *tripleDESBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *tripleDESBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type cast5BlockCrypt struct {
encbuf [cast5.BlockSize]byte
decbuf [2 * cast5.BlockSize]byte
block cipher.Block
}
// NewCast5BlockCrypt https://en.wikipedia.org/wiki/CAST-128
func NewCast5BlockCrypt(key []byte) (BlockCrypt, error) {
c := new(cast5BlockCrypt)
block, err := cast5.NewCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *cast5BlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *cast5BlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type blowfishBlockCrypt struct {
encbuf [blowfish.BlockSize]byte
decbuf [2 * blowfish.BlockSize]byte
block cipher.Block
}
// NewBlowfishBlockCrypt https://en.wikipedia.org/wiki/Blowfish_(cipher)
func NewBlowfishBlockCrypt(key []byte) (BlockCrypt, error) {
c := new(blowfishBlockCrypt)
block, err := blowfish.NewCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *blowfishBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *blowfishBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type aesBlockCrypt struct {
encbuf [aes.BlockSize]byte
decbuf [2 * aes.BlockSize]byte
block cipher.Block
}
// NewAESBlockCrypt https://en.wikipedia.org/wiki/Advanced_Encryption_Standard
func NewAESBlockCrypt(key []byte) (BlockCrypt, error) {
c := new(aesBlockCrypt)
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *aesBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *aesBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type teaBlockCrypt struct {
encbuf [tea.BlockSize]byte
decbuf [2 * tea.BlockSize]byte
block cipher.Block
}
// NewTEABlockCrypt https://en.wikipedia.org/wiki/Tiny_Encryption_Algorithm
func NewTEABlockCrypt(key []byte) (BlockCrypt, error) {
c := new(teaBlockCrypt)
block, err := tea.NewCipherWithRounds(key, 16)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *teaBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *teaBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type xteaBlockCrypt struct {
encbuf [xtea.BlockSize]byte
decbuf [2 * xtea.BlockSize]byte
block cipher.Block
}
// NewXTEABlockCrypt https://en.wikipedia.org/wiki/XTEA
func NewXTEABlockCrypt(key []byte) (BlockCrypt, error) {
c := new(xteaBlockCrypt)
block, err := xtea.NewCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *xteaBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *xteaBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type simpleXORBlockCrypt struct {
xortbl []byte
}
// NewSimpleXORBlockCrypt simple xor with key expanding
func NewSimpleXORBlockCrypt(key []byte) (BlockCrypt, error) {
c := new(simpleXORBlockCrypt)
c.xortbl = pbkdf2.Key(key, []byte(saltxor), 32, mtuLimit, sha1.New)
return c, nil
}
func (c *simpleXORBlockCrypt) Encrypt(dst, src []byte) { xor.Bytes(dst, src, c.xortbl) }
func (c *simpleXORBlockCrypt) Decrypt(dst, src []byte) { xor.Bytes(dst, src, c.xortbl) }
type noneBlockCrypt struct{}
// NewNoneBlockCrypt does nothing but copying
func NewNoneBlockCrypt(key []byte) (BlockCrypt, error) {
return new(noneBlockCrypt), nil
}
func (c *noneBlockCrypt) Encrypt(dst, src []byte) { copy(dst, src) }
func (c *noneBlockCrypt) Decrypt(dst, src []byte) { copy(dst, src) }
// packet encryption with local CFB mode
func encrypt(block cipher.Block, dst, src, buf []byte) {
switch block.BlockSize() {
case 8:
encrypt8(block, dst, src, buf)
case 16:
encrypt16(block, dst, src, buf)
default:
panic("unsupported cipher block size")
}
}
// optimized encryption for the ciphers which works in 8-bytes
func encrypt8(block cipher.Block, dst, src, buf []byte) {
tbl := buf[:8]
block.Encrypt(tbl, initialVector)
n := len(src) / 8
base := 0
repeat := n / 8
left := n % 8
ptr_tbl := (*uint64)(unsafe.Pointer(&tbl[0]))
for i := 0; i < repeat; i++ {
s := src[base:][0:64]
d := dst[base:][0:64]
// 1
*(*uint64)(unsafe.Pointer(&d[0])) = *(*uint64)(unsafe.Pointer(&s[0])) ^ *ptr_tbl
block.Encrypt(tbl, d[0:8])
// 2
*(*uint64)(unsafe.Pointer(&d[8])) = *(*uint64)(unsafe.Pointer(&s[8])) ^ *ptr_tbl
block.Encrypt(tbl, d[8:16])
// 3
*(*uint64)(unsafe.Pointer(&d[16])) = *(*uint64)(unsafe.Pointer(&s[16])) ^ *ptr_tbl
block.Encrypt(tbl, d[16:24])
// 4
*(*uint64)(unsafe.Pointer(&d[24])) = *(*uint64)(unsafe.Pointer(&s[24])) ^ *ptr_tbl
block.Encrypt(tbl, d[24:32])
// 5
*(*uint64)(unsafe.Pointer(&d[32])) = *(*uint64)(unsafe.Pointer(&s[32])) ^ *ptr_tbl
block.Encrypt(tbl, d[32:40])
// 6
*(*uint64)(unsafe.Pointer(&d[40])) = *(*uint64)(unsafe.Pointer(&s[40])) ^ *ptr_tbl
block.Encrypt(tbl, d[40:48])
// 7
*(*uint64)(unsafe.Pointer(&d[48])) = *(*uint64)(unsafe.Pointer(&s[48])) ^ *ptr_tbl
block.Encrypt(tbl, d[48:56])
// 8
*(*uint64)(unsafe.Pointer(&d[56])) = *(*uint64)(unsafe.Pointer(&s[56])) ^ *ptr_tbl
block.Encrypt(tbl, d[56:64])
base += 64
}
switch left {
case 7:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 6:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 5:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 4:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 3:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 2:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 1:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 0:
xorBytes(dst[base:], src[base:], tbl)
}
}
// optimized encryption for the ciphers which works in 16-bytes
func encrypt16(block cipher.Block, dst, src, buf []byte) {
tbl := buf[:16]
block.Encrypt(tbl, initialVector)
n := len(src) / 16
base := 0
repeat := n / 8
left := n % 8
for i := 0; i < repeat; i++ {
s := src[base:][0:128]
d := dst[base:][0:128]
// 1
xor.Bytes16Align(d[0:16], s[0:16], tbl)
block.Encrypt(tbl, d[0:16])
// 2
xor.Bytes16Align(d[16:32], s[16:32], tbl)
block.Encrypt(tbl, d[16:32])
// 3
xor.Bytes16Align(d[32:48], s[32:48], tbl)
block.Encrypt(tbl, d[32:48])
// 4
xor.Bytes16Align(d[48:64], s[48:64], tbl)
block.Encrypt(tbl, d[48:64])
// 5
xor.Bytes16Align(d[64:80], s[64:80], tbl)
block.Encrypt(tbl, d[64:80])
// 6
xor.Bytes16Align(d[80:96], s[80:96], tbl)
block.Encrypt(tbl, d[80:96])
// 7
xor.Bytes16Align(d[96:112], s[96:112], tbl)
block.Encrypt(tbl, d[96:112])
// 8
xor.Bytes16Align(d[112:128], s[112:128], tbl)
block.Encrypt(tbl, d[112:128])
base += 128
}
switch left {
case 7:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 6:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 5:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 4:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 3:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 2:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 1:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 0:
xorBytes(dst[base:], src[base:], tbl)
}
}
// decryption
func decrypt(block cipher.Block, dst, src, buf []byte) {
switch block.BlockSize() {
case 8:
decrypt8(block, dst, src, buf)
case 16:
decrypt16(block, dst, src, buf)
default:
panic("unsupported cipher block size")
}
}
// decrypt 8 bytes block, all byte slices are supposed to be 64bit aligned
func decrypt8(block cipher.Block, dst, src, buf []byte) {
tbl := buf[0:8]
next := buf[8:16]
block.Encrypt(tbl, initialVector)
n := len(src) / 8
base := 0
repeat := n / 8
left := n % 8
ptr_tbl := (*uint64)(unsafe.Pointer(&tbl[0]))
ptr_next := (*uint64)(unsafe.Pointer(&next[0]))
for i := 0; i < repeat; i++ {
s := src[base:][0:64]
d := dst[base:][0:64]
// 1
block.Encrypt(next, s[0:8])
*(*uint64)(unsafe.Pointer(&d[0])) = *(*uint64)(unsafe.Pointer(&s[0])) ^ *ptr_tbl
// 2
block.Encrypt(tbl, s[8:16])
*(*uint64)(unsafe.Pointer(&d[8])) = *(*uint64)(unsafe.Pointer(&s[8])) ^ *ptr_next
// 3
block.Encrypt(next, s[16:24])
*(*uint64)(unsafe.Pointer(&d[16])) = *(*uint64)(unsafe.Pointer(&s[16])) ^ *ptr_tbl
// 4
block.Encrypt(tbl, s[24:32])
*(*uint64)(unsafe.Pointer(&d[24])) = *(*uint64)(unsafe.Pointer(&s[24])) ^ *ptr_next
// 5
block.Encrypt(next, s[32:40])
*(*uint64)(unsafe.Pointer(&d[32])) = *(*uint64)(unsafe.Pointer(&s[32])) ^ *ptr_tbl
// 6
block.Encrypt(tbl, s[40:48])
*(*uint64)(unsafe.Pointer(&d[40])) = *(*uint64)(unsafe.Pointer(&s[40])) ^ *ptr_next
// 7
block.Encrypt(next, s[48:56])
*(*uint64)(unsafe.Pointer(&d[48])) = *(*uint64)(unsafe.Pointer(&s[48])) ^ *ptr_tbl
// 8
block.Encrypt(tbl, s[56:64])
*(*uint64)(unsafe.Pointer(&d[56])) = *(*uint64)(unsafe.Pointer(&s[56])) ^ *ptr_next
base += 64
}
switch left {
case 7:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 6:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 5:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 4:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 3:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 2:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 1:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 0:
xorBytes(dst[base:], src[base:], tbl)
}
}
func decrypt16(block cipher.Block, dst, src, buf []byte) {
tbl := buf[0:16]
next := buf[16:32]
block.Encrypt(tbl, initialVector)
n := len(src) / 16
base := 0
repeat := n / 8
left := n % 8
for i := 0; i < repeat; i++ {
s := src[base:][0:128]
d := dst[base:][0:128]
// 1
block.Encrypt(next, s[0:16])
xor.Bytes16Align(d[0:16], s[0:16], tbl)
// 2
block.Encrypt(tbl, s[16:32])
xor.Bytes16Align(d[16:32], s[16:32], next)
// 3
block.Encrypt(next, s[32:48])
xor.Bytes16Align(d[32:48], s[32:48], tbl)
// 4
block.Encrypt(tbl, s[48:64])
xor.Bytes16Align(d[48:64], s[48:64], next)
// 5
block.Encrypt(next, s[64:80])
xor.Bytes16Align(d[64:80], s[64:80], tbl)
// 6
block.Encrypt(tbl, s[80:96])
xor.Bytes16Align(d[80:96], s[80:96], next)
// 7
block.Encrypt(next, s[96:112])
xor.Bytes16Align(d[96:112], s[96:112], tbl)
// 8
block.Encrypt(tbl, s[112:128])
xor.Bytes16Align(d[112:128], s[112:128], next)
base += 128
}
switch left {
case 7:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 6:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 5:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 4:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 3:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 2:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 1:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 0:
xorBytes(dst[base:], src[base:], tbl)
}
}
// per bytes xors
func xorBytes(dst, a, b []byte) int {
n := len(a)
if len(b) < n {
n = len(b)
}
if n == 0 {
return 0
}
for i := 0; i < n; i++ {
dst[i] = a[i] ^ b[i]
}
return n
}