249 lines
5.1 KiB
Go
249 lines
5.1 KiB
Go
package dns
|
|
|
|
import (
	"bytes"
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"net/http"
	"sort"
	"sync"
	"time"

	"github.com/miekg/dns"
)
// DNSQuery is the aggregated record of the DNS queries seen for one domain.
type DNSQuery struct {
	// Domain and QueryType are the queried name and the textual record
	// type (for example "A") of the query that created this record.
	Domain, QueryType string
	// Timestamp is when the domain was most recently queried.
	Timestamp time.Time
	// Count is how many queries for the domain have been recorded.
	Count int
}
type DNSServer struct {
|
|
Port int
|
|
UpstreamDNS string
|
|
UseDoH bool
|
|
DoHURL string
|
|
queries map[string]*DNSQuery
|
|
queriesMutex sync.RWMutex
|
|
server *dns.Server
|
|
dnsClient *dns.Client
|
|
httpClient *http.Client
|
|
}
|
|
|
|
// NewDNSServer 创建DNS服务器
|
|
func NewDNSServer(port int, upstreamDNS string) *DNSServer {
|
|
return &DNSServer{
|
|
Port: port,
|
|
UpstreamDNS: upstreamDNS,
|
|
UseDoH: false,
|
|
queries: make(map[string]*DNSQuery),
|
|
dnsClient: &dns.Client{
|
|
Net: "udp",
|
|
Timeout: 3 * time.Second,
|
|
DialTimeout: 2 * time.Second,
|
|
ReadTimeout: 3 * time.Second,
|
|
WriteTimeout: 2 * time.Second,
|
|
},
|
|
}
|
|
}
|
|
|
|
// NewDNSServerWithDoH 创建使用DoH的DNS服务器
|
|
func NewDNSServerWithDoH(port int, dohURL string) *DNSServer {
|
|
return &DNSServer{
|
|
Port: port,
|
|
UseDoH: true,
|
|
DoHURL: dohURL,
|
|
queries: make(map[string]*DNSQuery),
|
|
httpClient: &http.Client{
|
|
Timeout: 3 * time.Second,
|
|
Transport: &http.Transport{
|
|
ForceAttemptHTTP2: true,
|
|
MaxIdleConns: 100,
|
|
MaxIdleConnsPerHost: 100,
|
|
IdleConnTimeout: 90 * time.Second,
|
|
TLSHandshakeTimeout: 2 * time.Second,
|
|
DisableCompression: false,
|
|
TLSClientConfig: &tls.Config{
|
|
InsecureSkipVerify: false,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
// Start 启动DNS服务器
|
|
func (s *DNSServer) Start() error {
|
|
s.server = &dns.Server{
|
|
Addr: fmt.Sprintf("0.0.0.0:%d", s.Port),
|
|
Net: "udp",
|
|
Handler: dns.HandlerFunc(s.handleDNSRequest),
|
|
UDPSize: 65535,
|
|
}
|
|
|
|
fmt.Printf("DNS Server started on 0.0.0.0:%d\n", s.Port)
|
|
if s.UseDoH {
|
|
fmt.Printf("Using DoH: %s\n", s.DoHURL)
|
|
} else {
|
|
fmt.Printf("Upstream DNS: %s\n", s.UpstreamDNS)
|
|
}
|
|
|
|
go func() {
|
|
if err := s.server.ListenAndServe(); err != nil {
|
|
fmt.Printf("DNS server error: %v\n", err)
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Stop 停止DNS服务器
|
|
func (s *DNSServer) Stop() {
|
|
if s.server != nil {
|
|
s.server.Shutdown()
|
|
}
|
|
}
|
|
|
|
// handleDNSRequest 处理DNS请求
|
|
func (s *DNSServer) handleDNSRequest(w dns.ResponseWriter, req *dns.Msg) {
|
|
// 记录查询
|
|
if len(req.Question) > 0 {
|
|
q := req.Question[0]
|
|
s.recordQuery(q.Name, dns.TypeToString[q.Qtype])
|
|
}
|
|
|
|
// 转发请求
|
|
var resp *dns.Msg
|
|
var err error
|
|
|
|
if s.UseDoH {
|
|
resp, err = s.forwardDoH(req)
|
|
} else {
|
|
resp, err = s.forwardUDP(req)
|
|
}
|
|
|
|
if err != nil {
|
|
// 返回 SERVFAIL
|
|
resp = new(dns.Msg)
|
|
resp.SetRcode(req, dns.RcodeServerFailure)
|
|
}
|
|
|
|
w.WriteMsg(resp)
|
|
}
|
|
|
|
// forwardUDP 通过UDP转发
|
|
func (s *DNSServer) forwardUDP(req *dns.Msg) (*dns.Msg, error) {
|
|
resp, _, err := s.dnsClient.Exchange(req, s.UpstreamDNS)
|
|
return resp, err
|
|
}
|
|
|
|
// forwardDoH 通过DoH转发
|
|
func (s *DNSServer) forwardDoH(req *dns.Msg) (*dns.Msg, error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
// 将 DNS 消息打包
|
|
packed, err := req.Pack()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 创建 HTTP POST 请求
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", s.DoHURL, bytes.NewReader(packed))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 设置 DoH 请求头
|
|
httpReq.Header.Set("Content-Type", "application/dns-message")
|
|
httpReq.Header.Set("Accept", "application/dns-message")
|
|
|
|
// 发送请求
|
|
resp, err := s.httpClient.Do(httpReq)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("DoH server returned status %d", resp.StatusCode)
|
|
}
|
|
|
|
// 读取响应
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 解析 DNS 响应
|
|
dnsResp := new(dns.Msg)
|
|
if err := dnsResp.Unpack(body); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return dnsResp, nil
|
|
}
|
|
|
|
// recordQuery 记录DNS查询
|
|
func (s *DNSServer) recordQuery(domain, queryType string) {
|
|
s.queriesMutex.Lock()
|
|
defer s.queriesMutex.Unlock()
|
|
|
|
key := domain
|
|
if q, exists := s.queries[key]; exists {
|
|
q.Count++
|
|
q.Timestamp = time.Now()
|
|
} else {
|
|
s.queries[key] = &DNSQuery{
|
|
Domain: domain,
|
|
QueryType: queryType,
|
|
Timestamp: time.Now(),
|
|
Count: 1,
|
|
}
|
|
}
|
|
|
|
// 限制缓存大小,防止内存溢出
|
|
if len(s.queries) > 500 {
|
|
// 删除最旧的条目
|
|
oldestKey := ""
|
|
oldestTime := time.Now()
|
|
for key, q := range s.queries {
|
|
if q.Timestamp.Before(oldestTime) {
|
|
oldestTime = q.Timestamp
|
|
oldestKey = key
|
|
}
|
|
}
|
|
if oldestKey != "" {
|
|
delete(s.queries, oldestKey)
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetRecentQueries 获取最近的DNS查询
|
|
func (s *DNSServer) GetRecentQueries(limit int) []DNSQuery {
|
|
s.queriesMutex.RLock()
|
|
defer s.queriesMutex.RUnlock()
|
|
|
|
queries := make([]DNSQuery, 0, len(s.queries))
|
|
for _, q := range s.queries {
|
|
queries = append(queries, *q)
|
|
}
|
|
|
|
// 按时间排序(最新的在前)
|
|
for i := 0; i < len(queries)-1; i++ {
|
|
for j := i + 1; j < len(queries); j++ {
|
|
if queries[j].Timestamp.After(queries[i].Timestamp) {
|
|
queries[i], queries[j] = queries[j], queries[i]
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(queries) > limit {
|
|
queries = queries[:limit]
|
|
}
|
|
|
|
return queries
|
|
}
|