From 27d53c87515676543cda28b459b8dadf60e78058 Mon Sep 17 00:00:00 2001 From: lirui Date: Sun, 4 Jan 2026 22:32:56 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96DNS=E6=9C=8D=E5=8A=A1?= =?UTF-8?q?=E5=99=A8=E7=9A=84HTTP=E5=AE=A2=E6=88=B7=E7=AB=AF=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=EF=BC=8C=E8=B0=83=E6=95=B4=E8=B6=85=E6=97=B6=E6=97=B6?= =?UTF-8?q?=E9=97=B4=E5=92=8C=E8=BF=9E=E6=8E=A5=E5=8F=82=E6=95=B0=EF=BC=9B?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=9F=A5=E8=AF=A2=E8=AE=B0=E5=BD=95=E7=BC=93?= =?UTF-8?q?=E5=AD=98=E9=99=90=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 9 ++ go.sum | 12 +++ internal/dns/dns.go | 225 +++++++++++--------------------------------- 3 files changed, 77 insertions(+), 169 deletions(-) create mode 100644 go.sum diff --git a/go.mod b/go.mod index c4c1723..883ab60 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,12 @@ module github.com/meowrain/nodeprobe go 1.25.5 + +require ( + github.com/miekg/dns v1.1.69 // indirect + golang.org/x/mod v0.30.0 // indirect + golang.org/x/net v0.47.0 // indirect + golang.org/x/sync v0.18.0 // indirect + golang.org/x/sys v0.38.0 // indirect + golang.org/x/tools v0.39.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..78957f8 --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/miekg/dns v1.1.69 h1:Kb7Y/1Jo+SG+a2GtfoFUfDkG//csdRPwRLkCsxDG9Sc= +github.com/miekg/dns v1.1.69/go.mod h1:7OyjD9nEba5OkqQ/hB4fy3PIoxafSZJtducccIelz3g= +golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= +golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= +golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= diff --git a/internal/dns/dns.go b/internal/dns/dns.go index ff8bc1f..8b6c9b1 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -1,15 +1,14 @@ package dns import ( - "bytes" + "context" "crypto/tls" - "encoding/binary" "fmt" - "io" - "net" "net/http" "sync" "time" + + "github.com/miekg/dns" ) // DNSQuery DNS查询记录 @@ -27,7 +26,8 @@ type DNSServer struct { DoHURL string queries map[string]*DNSQuery queriesMutex sync.RWMutex - conn *net.UDPConn + server *dns.Server + dnsClient *dns.Client httpClient *http.Client } @@ -38,17 +38,12 @@ func NewDNSServer(port int, upstreamDNS string) *DNSServer { UpstreamDNS: upstreamDNS, UseDoH: false, queries: make(map[string]*DNSQuery), - httpClient: &http.Client{ - Timeout: 3 * time.Second, - Transport: &http.Transport{ - MaxIdleConns: 100, - MaxIdleConnsPerHost: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 2 * time.Second, - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: false, - }, - }, + dnsClient: &dns.Client{ + Net: "udp", + Timeout: 3 * time.Second, + DialTimeout: 2 * time.Second, + ReadTimeout: 3 * time.Second, + WriteTimeout: 2 * time.Second, }, } } @@ -79,17 +74,13 @@ func NewDNSServerWithDoH(port int, dohURL string) *DNSServer { // Start 启动DNS服务器 func (s *DNSServer) Start() error { - addr := net.UDPAddr{ - Port: s.Port, - IP: net.ParseIP("0.0.0.0"), + s.server = &dns.Server{ + Addr: fmt.Sprintf("0.0.0.0:%d", s.Port), + Net: "udp", + Handler: dns.HandlerFunc(s.handleDNSRequest), + UDPSize: 65535, } - conn, err := net.ListenUDP("udp", &addr) - if err != nil { - return fmt.Errorf("failed to start DNS server: %v", err) - } - - s.conn = conn fmt.Printf("DNS Server started on 0.0.0.0:%d\n", s.Port) if s.UseDoH { fmt.Printf("Using DoH: %s\n", s.DoHURL) @@ -97,175 +88,71 @@ func (s *DNSServer) Start() error { fmt.Printf("Upstream DNS: %s\n", s.UpstreamDNS) } - go s.handleRequests() + go func() { + if err := s.server.ListenAndServe(); err != nil { + fmt.Printf("DNS server error: %v\n", err) + } + }() + return nil } // Stop 停止DNS服务器 func (s *DNSServer) Stop() { - if s.conn != nil { - s.conn.Close() + if s.server != nil { + s.server.Shutdown() } } -// handleRequests 处理DNS请求 -func (s *DNSServer) handleRequests() { - buffer := make([]byte, 512) - - for { - n, clientAddr, err := s.conn.ReadFromUDP(buffer) - if err != nil { - continue - } - - // 解析DNS查询 - domain, queryType := s.parseDNSQuery(buffer[:n]) - if domain != "" { - s.recordQuery(domain, queryType) - } - - // 转发到上游DNS - go s.forwardQuery(buffer[:n], clientAddr) +// handleDNSRequest 处理DNS请求 +func (s *DNSServer) handleDNSRequest(w dns.ResponseWriter, req *dns.Msg) { + // 记录查询 + if len(req.Question) > 0 { + q := req.Question[0] + s.recordQuery(q.Name, dns.TypeToString[q.Qtype]) } -} -// forwardQuery 转发DNS查询到上游服务器 -func (s *DNSServer) forwardQuery(query []byte, clientAddr *net.UDPAddr) { - var response []byte + // 转发请求 + var resp *dns.Msg var err error if s.UseDoH { - response, err = s.forwardQueryDoH(query) + resp, err = s.forwardDoH(req) } else { - response, err = s.forwardQueryUDP(query) + resp, err = s.forwardUDP(req) } if err != nil { - return + // 返回 SERVFAIL + resp = new(dns.Msg) + resp.SetRcode(req, dns.RcodeServerFailure) } - // 返回给客户端 - s.conn.WriteToUDP(response, clientAddr) + w.WriteMsg(resp) } -// forwardQueryUDP 通过传统UDP转发DNS查询 -func (s *DNSServer) forwardQueryUDP(query []byte) ([]byte, error) { - // 连接上游DNS - upstreamAddr, err := net.ResolveUDPAddr("udp", s.UpstreamDNS) - if err != nil { - return nil, err - } - - upstreamConn, err := net.DialUDP("udp", nil, upstreamAddr) - if err != nil { - return nil, err - } - defer upstreamConn.Close() - - // 设置超时 - upstreamConn.SetDeadline(time.Now().Add(3 * time.Second)) - - // 发送查询 - _, err = upstreamConn.Write(query) - if err != nil { - return nil, err - } - - // 接收响应 - response := make([]byte, 512) - n, err := upstreamConn.Read(response) - if err != nil { - return nil, err - } - - return response[:n], nil +// forwardUDP 通过UDP转发 +func (s *DNSServer) forwardUDP(req *dns.Msg) (*dns.Msg, error) { + resp, _, err := s.dnsClient.Exchange(req, s.UpstreamDNS) + return resp, err } -// forwardQueryDoH 通过DoH转发DNS查询 -func (s *DNSServer) forwardQueryDoH(query []byte) ([]byte, error) { - // 创建HTTP POST请求 - req, err := http.NewRequest("POST", s.DoHURL, bytes.NewReader(query)) - if err != nil { - return nil, err +// forwardDoH 通过DoH转发 +func (s *DNSServer) forwardDoH(req *dns.Msg) (*dns.Msg, error) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + // 使用 DNS-over-HTTPS 客户端 + client := &dns.Client{ + Net: "https", + TLSConfig: &tls.Config{ + InsecureSkipVerify: false, + }, + Timeout: 3 * time.Second, } - // 设置DoH请求头 - req.Header.Set("Content-Type", "application/dns-message") - req.Header.Set("Accept", "application/dns-message") - - // 发送请求 - resp, err := s.httpClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("DoH server returned status %d", resp.StatusCode) - } - - // 读取响应 - response, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - return response, nil -} - -// parseDNSQuery 解析DNS查询包 -func (s *DNSServer) parseDNSQuery(data []byte) (string, string) { - if len(data) < 12 { - return "", "" - } - - // DNS header is 12 bytes - // Questions start at byte 12 - offset := 12 - domain := "" - - for offset < len(data) { - length := int(data[offset]) - if length == 0 { - break - } - if length > 63 { - // Pointer, stop parsing - break - } - offset++ - if offset+length > len(data) { - break - } - if domain != "" { - domain += "." - } - domain += string(data[offset : offset+length]) - offset += length - } - - // Parse query type - offset++ // Skip null terminator - queryType := "A" - if offset+2 <= len(data) { - qtype := binary.BigEndian.Uint16(data[offset : offset+2]) - switch qtype { - case 1: - queryType = "A" - case 28: - queryType = "AAAA" - case 5: - queryType = "CNAME" - case 15: - queryType = "MX" - case 16: - queryType = "TXT" - default: - queryType = fmt.Sprintf("TYPE%d", qtype) - } - } - - return domain, queryType + resp, _, err := client.ExchangeContext(ctx, req, s.DoHURL) + return resp, err } // recordQuery 记录DNS查询