Files
nodeprobe/internal/dns/dns.go
2026-01-04 22:18:43 +08:00

321 lines
6.2 KiB
Go

package dns
import (
	"bytes"
	"crypto/tls"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"sort"
	"sync"
	"time"
)
// DNSQuery is the statistics record kept for one queried domain name.
type DNSQuery struct {
	// Domain is the question name parsed from the query packet, with labels
	// joined by "." and no trailing dot.
	Domain string
	// QueryType is the textual record type stored for this domain, e.g. "A",
	// "AAAA", or the generic "TYPE<n>" form for types without a mnemonic.
	QueryType string
	// Timestamp is when the domain was last recorded as queried.
	Timestamp time.Time
	// Count is how many queries for the domain have been recorded.
	Count int
}
// DNSServer is a small forwarding DNS server: it receives queries on a UDP
// socket, records the question names it sees, and relays each query either to
// a plain-UDP upstream resolver or to a DNS-over-HTTPS (DoH) endpoint.
type DNSServer struct {
	// Port is the UDP port that Start listens on.
	Port int
	// UpstreamDNS is the plain-UDP resolver address. It is passed to
	// net.ResolveUDPAddr, so it must be in "host:port" form (e.g. "8.8.8.8:53").
	UpstreamDNS string
	// UseDoH selects forwarding via HTTP POST to DoHURL instead of UDP
	// forwarding to UpstreamDNS.
	UseDoH bool
	// DoHURL is the endpoint that DoH queries are POSTed to.
	DoHURL string

	// queries holds per-domain query statistics, keyed by domain name.
	queries map[string]*DNSQuery
	// queriesMutex guards queries.
	queriesMutex sync.RWMutex
	// conn is the listening socket; nil until Start succeeds.
	conn *net.UDPConn
	// httpClient performs the DoH requests.
	httpClient *http.Client
}
// NewDNSServer returns a DNSServer (not yet started) that will listen on port
// and forward queries over plain UDP to upstreamDNS, a "host:port" address.
//
// The server also carries an HTTP client with a 5-second timeout and TLS
// certificate verification enabled, used if UseDoH is switched on later.
func NewDNSServer(port int, upstreamDNS string) *DNSServer {
	return &DNSServer{
		Port:        port,
		UpstreamDNS: upstreamDNS,
		UseDoH:      false,
		queries:     make(map[string]*DNSQuery),
		httpClient: &http.Client{
			Timeout: 5 * time.Second,
			Transport: &http.Transport{
				TLSClientConfig: &tls.Config{
					InsecureSkipVerify: false,
				},
				// Supplying a custom TLSClientConfig disables net/http's automatic
				// HTTP/2 support unless it is forced back on; RFC 8484 recommends
				// HTTP/2 as the minimum version for DoH.
				ForceAttemptHTTP2: true,
			},
		},
	}
}
// NewDNSServerWithDoH returns a DNSServer (not yet started) that will listen
// on port and forward every query as a DNS-over-HTTPS POST request to dohURL.
// UpstreamDNS is left empty; it is not consulted while UseDoH is true.
//
// The HTTP client has a 5-second timeout and TLS certificate verification
// enabled.
func NewDNSServerWithDoH(port int, dohURL string) *DNSServer {
	return &DNSServer{
		Port:    port,
		UseDoH:  true,
		DoHURL:  dohURL,
		queries: make(map[string]*DNSQuery),
		httpClient: &http.Client{
			Timeout: 5 * time.Second,
			Transport: &http.Transport{
				TLSClientConfig: &tls.Config{
					InsecureSkipVerify: false,
				},
				// Supplying a custom TLSClientConfig disables net/http's automatic
				// HTTP/2 support unless it is forced back on; RFC 8484 recommends
				// HTTP/2 as the minimum version for DoH.
				ForceAttemptHTTP2: true,
			},
		},
	}
}
// Start binds a UDP socket on the wildcard address at s.Port, prints the
// forwarding configuration to stdout, and launches the receive loop in a
// background goroutine. It returns once the socket is bound.
//
// A bind failure is returned wrapped with %w, so callers can inspect the
// underlying network error with errors.Is / errors.As.
func (s *DNSServer) Start() error {
	addr := net.UDPAddr{
		Port: s.Port,
		IP:   net.ParseIP("0.0.0.0"),
	}
	conn, err := net.ListenUDP("udp", &addr)
	if err != nil {
		return fmt.Errorf("failed to start DNS server: %w", err)
	}
	s.conn = conn

	fmt.Printf("DNS Server started on 0.0.0.0:%d\n", s.Port)
	if s.UseDoH {
		fmt.Printf("Using DoH: %s\n", s.DoHURL)
	} else {
		fmt.Printf("Upstream DNS: %s\n", s.UpstreamDNS)
	}

	go s.handleRequests()
	return nil
}
// Stop closes the listening socket if Start has opened one; it is a no-op
// otherwise. The error from Close is intentionally discarded, so calling Stop
// more than once is harmless.
func (s *DNSServer) Stop() {
	conn := s.conn
	if conn == nil {
		return
	}
	_ = conn.Close()
}
// handleRequests is the receive loop started by Start. For every datagram it
// records the question (when one can be parsed) and hands a private copy of
// the packet to forwardQuery in its own goroutine. It returns once the socket
// has been closed (for example by Stop).
func (s *DNSServer) handleRequests() {
	// Sized for the largest possible UDP payload so EDNS0 queries bigger than
	// the classic 512-byte limit are not silently truncated by ReadFromUDP.
	const maxDatagram = 65535
	buffer := make([]byte, maxDatagram)

	for {
		n, clientAddr, err := s.conn.ReadFromUDP(buffer)
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				// The socket is gone; every further read would fail immediately,
				// so stop instead of spinning on the same error forever.
				return
			}
			// Other read errors are treated as transient; keep serving.
			continue
		}

		// buffer is overwritten by the next ReadFromUDP while the forwarding
		// goroutine may still be using the packet, so give it its own copy.
		query := make([]byte, n)
		copy(query, buffer[:n])

		if domain, queryType := s.parseDNSQuery(query); domain != "" {
			s.recordQuery(domain, queryType)
		}
		go s.forwardQuery(query, clientAddr)
	}
}
// forwardQuery resolves one client query through the configured upstream
// (DoH when s.UseDoH is set, plain UDP otherwise) and relays the answer back
// to clientAddr over the server's socket. If the upstream exchange fails, no
// reply is sent and the client is left to time out.
func (s *DNSServer) forwardQuery(query []byte, clientAddr *net.UDPAddr) {
	exchange := s.forwardQueryUDP
	if s.UseDoH {
		exchange = s.forwardQueryDoH
	}

	answer, err := exchange(query)
	if err != nil {
		return
	}
	// Best effort: a failed write back to the client is not reported.
	_, _ = s.conn.WriteToUDP(answer, clientAddr)
}
// forwardQueryUDP sends query to s.UpstreamDNS (a "host:port" address) over a
// fresh connected UDP socket and returns the first reply whose DNS transaction
// ID matches the query's. The whole exchange is bounded by a 5-second deadline;
// resolution, dial, write, and read errors (including the deadline expiring)
// are returned unchanged.
func (s *DNSServer) forwardQueryUDP(query []byte) ([]byte, error) {
	upstreamAddr, err := net.ResolveUDPAddr("udp", s.UpstreamDNS)
	if err != nil {
		return nil, err
	}
	upstreamConn, err := net.DialUDP("udp", nil, upstreamAddr)
	if err != nil {
		return nil, err
	}
	defer upstreamConn.Close()

	if err := upstreamConn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
		return nil, err
	}
	if _, err := upstreamConn.Write(query); err != nil {
		return nil, err
	}

	// Answers can exceed 512 bytes when the client negotiated a larger payload
	// via EDNS0; a smaller buffer would silently truncate them into malformed
	// packets, so size it for the largest possible datagram.
	response := make([]byte, 65535)
	for {
		n, err := upstreamConn.Read(response)
		if err != nil {
			return nil, err
		}
		// Ignore datagrams whose transaction ID does not match the query
		// (stray or spoofed replies) and keep waiting until the deadline.
		if len(query) >= 2 && (n < 2 || !bytes.Equal(response[:2], query[:2])) {
			continue
		}
		return response[:n], nil
	}
}
// forwardQueryDoH POSTs query to s.DoHURL in DNS wire format
// ("application/dns-message") using s.httpClient and returns the raw response
// body. It returns an error for transport failures, for any status other than
// 200 OK, and for bodies larger than a DNS message can be.
func (s *DNSServer) forwardQueryDoH(query []byte) ([]byte, error) {
	// A DNS message can never exceed 65535 bytes (its length must fit the
	// 16-bit length prefix used on TCP); refusing anything bigger avoids
	// buffering an arbitrarily large body from the server.
	const maxDNSMessage = 65535

	req, err := http.NewRequest(http.MethodPost, s.DoHURL, bytes.NewReader(query))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/dns-message")
	req.Header.Set("Accept", "application/dns-message")

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("DoH server returned status %d", resp.StatusCode)
	}

	// Read one byte past the limit so an oversized body is detectable.
	response, err := io.ReadAll(io.LimitReader(resp.Body, maxDNSMessage+1))
	if err != nil {
		return nil, err
	}
	if len(response) > maxDNSMessage {
		return nil, fmt.Errorf("DoH response exceeds %d bytes", maxDNSMessage)
	}
	return response, nil
}
// parseDNSQuery extracts the first question's name and type from a raw DNS
// query packet.
//
// It returns the name with labels joined by "." (no trailing dot; "" for the
// root) and a mnemonic for QTYPE such as "A" or "AAAA", falling back to the
// generic "TYPE<n>" form for types without a mnemonic. It returns ("", "")
// when the packet is shorter than the 12-byte header or the question name is
// malformed: a truncated label, no zero-length terminator, or a compression
// pointer / reserved label type (which this parser does not follow). If the
// name parses but the packet ends before QTYPE, the type defaults to "A".
func (s *DNSServer) parseDNSQuery(data []byte) (string, string) {
	const headerLen = 12
	if len(data) < headerLen {
		return "", ""
	}

	offset := headerLen
	name := make([]byte, 0, 64)
	for {
		if offset >= len(data) {
			// Ran out of data before the zero-length root label.
			return "", ""
		}
		length := int(data[offset])
		offset++
		if length == 0 {
			break
		}
		// Values above 63 have one of the top two bits set: a compression
		// pointer (0b11) or a reserved label type. Bail out rather than read
		// QTYPE from the wrong offset.
		if length > 63 || offset+length > len(data) {
			return "", ""
		}
		if len(name) > 0 {
			name = append(name, '.')
		}
		name = append(name, data[offset:offset+length]...)
		offset += length
	}
	domain := string(name)

	// offset now points just past the terminating zero label, at QTYPE.
	queryType := "A"
	if offset+2 <= len(data) {
		switch qtype := binary.BigEndian.Uint16(data[offset : offset+2]); qtype {
		case 1:
			queryType = "A"
		case 2:
			queryType = "NS"
		case 5:
			queryType = "CNAME"
		case 6:
			queryType = "SOA"
		case 12:
			queryType = "PTR"
		case 15:
			queryType = "MX"
		case 16:
			queryType = "TXT"
		case 28:
			queryType = "AAAA"
		case 33:
			queryType = "SRV"
		case 64:
			queryType = "SVCB"
		case 65:
			queryType = "HTTPS"
		case 255:
			queryType = "ANY"
		default:
			queryType = fmt.Sprintf("TYPE%d", qtype)
		}
	}
	return domain, queryType
}
// recordQuery records that domain was just queried with type queryType.
//
// Statistics are kept per domain. A repeat query increments Count and
// refreshes both Timestamp and QueryType, so the entry always describes the
// most recent query. At most 200 domains are retained: when a new domain
// pushes the table past that, the least recently seen other domain is
// evicted. All access to the table happens under queriesMutex.
func (s *DNSServer) recordQuery(domain, queryType string) {
	const maxEntries = 200

	s.queriesMutex.Lock()
	defer s.queriesMutex.Unlock()

	now := time.Now()
	if q, ok := s.queries[domain]; ok {
		q.Count++
		q.Timestamp = now
		q.QueryType = queryType
	} else {
		s.queries[domain] = &DNSQuery{
			Domain:    domain,
			QueryType: queryType,
			Timestamp: now,
			Count:     1,
		}
	}

	if len(s.queries) <= maxEntries {
		return
	}

	// Seed the search with the first candidate rather than time.Now(). That
	// way some entry is always evicted, even when timestamps tie, and the
	// domain just recorded is never the one dropped.
	victim, victimTime, found := "", time.Time{}, false
	for k, q := range s.queries {
		if k == domain {
			continue
		}
		if !found || q.Timestamp.Before(victimTime) {
			victim, victimTime, found = k, q.Timestamp, true
		}
	}
	if found {
		delete(s.queries, victim)
	}
}
// GetRecentQueries returns up to limit recorded queries, most recently seen
// first. The values are copies, so callers may keep or modify them without
// affecting the server. A limit of zero or less yields an empty slice.
func (s *DNSServer) GetRecentQueries(limit int) []DNSQuery {
	// Snapshot under the read lock, then sort outside it so writers in
	// recordQuery are not blocked by the sort.
	s.queriesMutex.RLock()
	queries := make([]DNSQuery, 0, len(s.queries))
	for _, q := range s.queries {
		queries = append(queries, *q)
	}
	s.queriesMutex.RUnlock()

	sort.Slice(queries, func(i, j int) bool {
		return queries[i].Timestamp.After(queries[j].Timestamp)
	})

	// A negative limit would otherwise panic in the slice expression below.
	if limit < 0 {
		limit = 0
	}
	if len(queries) > limit {
		queries = queries[:limit]
	}
	return queries
}