// Package dns implements a small forwarding DNS server. Incoming UDP queries
// are relayed to an upstream resolver, either over plain UDP or over
// DNS-over-HTTPS (DoH), and the queried domains are kept in a bounded
// in-memory log.
package dns

import (
	"bytes"
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"net/http"
	"sort"
	"sync"
	"time"

	"github.com/miekg/dns"
)

const (
	// maxTrackedQueries caps the number of distinct domains kept in the
	// query log so that it cannot grow without bound.
	maxTrackedQueries = 500

	// maxDoHResponseSize is the largest possible DNS message (16-bit length).
	// It bounds how much of a DoH response body is read into memory.
	maxDoHResponseSize = 65535
)

// DNSQuery is one entry of the query log, aggregated per domain name.
type DNSQuery struct {
	Domain    string    // queried name, as passed to recordQuery
	QueryType string    // record type of the most recent query for Domain
	Timestamp time.Time // time of the most recent query for Domain
	Count     int       // number of queries recorded for Domain
}

// DNSServer is a forwarding DNS server that records the queries it relays.
type DNSServer struct {
	Port        int    // UDP port the server listens on
	UpstreamDNS string // upstream resolver address; used when UseDoH is false
	UseDoH      bool   // forward over DNS-over-HTTPS instead of plain UDP
	DoHURL      string // DoH endpoint; used when UseDoH is true

	queries      map[string]*DNSQuery // query log keyed by domain name
	queriesMutex sync.RWMutex         // guards queries

	server     *dns.Server  // set by Start once the listener is up
	dnsClient  *dns.Client  // set by NewDNSServer
	httpClient *http.Client // set by NewDNSServerWithDoH
}

// NewDNSServer creates a DNS server that listens on port and forwards queries
// over UDP to upstreamDNS.
func NewDNSServer(port int, upstreamDNS string) *DNSServer {
	return &DNSServer{
		Port:        port,
		UpstreamDNS: upstreamDNS,
		UseDoH:      false,
		queries:     make(map[string]*DNSQuery),
		dnsClient: &dns.Client{
			Net:          "udp",
			Timeout:      3 * time.Second,
			DialTimeout:  2 * time.Second,
			ReadTimeout:  3 * time.Second,
			WriteTimeout: 2 * time.Second,
		},
	}
}

// NewDNSServerWithDoH creates a DNS server that listens on port and forwards
// queries to the DNS-over-HTTPS endpoint dohURL.
func NewDNSServerWithDoH(port int, dohURL string) *DNSServer {
	return &DNSServer{
		Port:    port,
		UseDoH:  true,
		DoHURL:  dohURL,
		queries: make(map[string]*DNSQuery),
		httpClient: &http.Client{
			Timeout: 3 * time.Second,
			Transport: &http.Transport{
				ForceAttemptHTTP2:   true,
				MaxIdleConns:        100,
				MaxIdleConnsPerHost: 100,
				IdleConnTimeout:     90 * time.Second,
				TLSHandshakeTimeout: 2 * time.Second,
				DisableCompression:  false,
				TLSClientConfig: &tls.Config{
					InsecureSkipVerify: false,
				},
			},
		},
	}
}

// Start binds the UDP listener on 0.0.0.0:Port and serves requests in a
// background goroutine. It blocks only until the listener is ready and
// returns an error if the server fails before that point (for example when
// the port is already in use). Errors that occur after startup are printed.
func (s *DNSServer) Start() error {
	started := make(chan struct{})
	srv := &dns.Server{
		Addr:    fmt.Sprintf("0.0.0.0:%d", s.Port),
		Net:     "udp",
		Handler: dns.HandlerFunc(s.handleDNSRequest),
		UDPSize: 65535,
		// Signals that the socket is bound and the serve loop is about to run.
		NotifyStartedFunc: func() { close(started) },
	}

	failed := make(chan error, 1)
	go func() {
		err := srv.ListenAndServe()
		select {
		case <-started:
			// The server had been running; report abnormal termination.
			if err != nil {
				fmt.Printf("DNS server error: %v\n", err)
			}
		default:
			// ListenAndServe returned before the listener was ready, so
			// hand the failure to the waiting Start call.
			failed <- err
		}
	}()

	select {
	case err := <-failed:
		if err == nil {
			return fmt.Errorf("DNS server on port %d exited before it started listening", s.Port)
		}
		return fmt.Errorf("starting DNS server on port %d: %w", s.Port, err)
	case <-started:
	}

	s.server = srv
	fmt.Printf("DNS Server started on 0.0.0.0:%d\n", s.Port)
	if s.UseDoH {
		fmt.Printf("Using DoH: %s\n", s.DoHURL)
	} else {
		fmt.Printf("Upstream DNS: %s\n", s.UpstreamDNS)
	}
	return nil
}

// Stop shuts the server down. It is a no-op if the server was never started.
// A shutdown error is printed rather than returned.
func (s *DNSServer) Stop() {
	if s.server == nil {
		return
	}
	if err := s.server.Shutdown(); err != nil {
		fmt.Printf("DNS server shutdown error: %v\n", err)
	}
}

// handleDNSRequest records the first question of req, forwards req upstream
// (DoH or UDP depending on UseDoH) and writes the answer to w. If forwarding
// fails, a SERVFAIL reply is sent instead.
func (s *DNSServer) handleDNSRequest(w dns.ResponseWriter, req *dns.Msg) {
	if len(req.Question) > 0 {
		q := req.Question[0]
		s.recordQuery(q.Name, dns.TypeToString[q.Qtype])
	}

	var (
		resp *dns.Msg
		err  error
	)
	if s.UseDoH {
		resp, err = s.forwardDoH(req)
	} else {
		resp, err = s.forwardUDP(req)
	}

	// A missing response is treated like an error so that a nil message is
	// never handed to WriteMsg.
	if err != nil || resp == nil {
		resp = new(dns.Msg)
		resp.SetRcode(req, dns.RcodeServerFailure)
	}

	if err := w.WriteMsg(resp); err != nil {
		fmt.Printf("DNS response write error: %v\n", err)
	}
}

// forwardUDP sends req to UpstreamDNS using the server's DNS client and
// returns the upstream response.
func (s *DNSServer) forwardUDP(req *dns.Msg) (*dns.Msg, error) {
	// UseDoH is exported and may be flipped on a server built for DoH, in
	// which case no DNS client exists.
	if s.dnsClient == nil {
		return nil, fmt.Errorf("UDP forwarding requested but no DNS client is configured")
	}

	resp, _, err := s.dnsClient.Exchange(req, s.UpstreamDNS)
	if err != nil {
		return nil, fmt.Errorf("forwarding to upstream %s: %w", s.UpstreamDNS, err)
	}
	return resp, nil
}

// forwardDoH sends req to DoHURL as an HTTP POST with an
// application/dns-message body and returns the decoded response. The whole
// exchange is limited to three seconds.
func (s *DNSServer) forwardDoH(req *dns.Msg) (*dns.Msg, error) {
	// UseDoH is exported and may be flipped on a server built for UDP, in
	// which case no HTTP client exists.
	if s.httpClient == nil {
		return nil, fmt.Errorf("DoH forwarding requested but no HTTP client is configured")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	// Serialize the query into DNS wire format.
	packed, err := req.Pack()
	if err != nil {
		return nil, fmt.Errorf("packing DNS query: %w", err)
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, s.DoHURL, bytes.NewReader(packed))
	if err != nil {
		return nil, fmt.Errorf("building DoH request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/dns-message")
	httpReq.Header.Set("Accept", "application/dns-message")

	resp, err := s.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("sending DoH request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("DoH server returned status %d", resp.StatusCode)
	}

	// Read one byte past the limit so an oversized body can be told apart
	// from one that is exactly at the limit.
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxDoHResponseSize+1))
	if err != nil {
		return nil, fmt.Errorf("reading DoH response: %w", err)
	}
	if len(body) > maxDoHResponseSize {
		return nil, fmt.Errorf("DoH response exceeds %d bytes", maxDoHResponseSize)
	}

	dnsResp := new(dns.Msg)
	if err := dnsResp.Unpack(body); err != nil {
		return nil, fmt.Errorf("unpacking DoH response: %w", err)
	}
	return dnsResp, nil
}

// recordQuery notes a query for domain in the query log. Entries are keyed by
// domain: a repeated domain bumps its count and refreshes its timestamp and
// query type. When a new domain pushes the log past maxTrackedQueries, the
// least recently queried other entry is evicted.
func (s *DNSServer) recordQuery(domain, queryType string) {
	now := time.Now()

	s.queriesMutex.Lock()
	defer s.queriesMutex.Unlock()

	if s.queries == nil {
		s.queries = make(map[string]*DNSQuery)
	}

	if q, ok := s.queries[domain]; ok {
		q.Count++
		q.Timestamp = now
		// Keep the type in step with Timestamp: both describe the latest query.
		q.QueryType = queryType
		return
	}

	s.queries[domain] = &DNSQuery{
		Domain:    domain,
		QueryType: queryType,
		Timestamp: now,
		Count:     1,
	}
	if len(s.queries) <= maxTrackedQueries {
		return
	}

	// Find the stalest entry, never the one that was just inserted.
	var (
		oldestKey  string
		oldestTime time.Time
		found      bool
	)
	for key, q := range s.queries {
		if key == domain {
			continue
		}
		if !found || q.Timestamp.Before(oldestTime) {
			oldestKey, oldestTime, found = key, q.Timestamp, true
		}
	}
	if found {
		delete(s.queries, oldestKey)
	}
}

// GetRecentQueries returns copies of up to limit logged queries, most recent
// first. A limit of zero or less yields an empty slice.
func (s *DNSServer) GetRecentQueries(limit int) []DNSQuery {
	// Copy under the lock, then sort the private snapshot without holding it.
	s.queriesMutex.RLock()
	queries := make([]DNSQuery, 0, len(s.queries))
	for _, q := range s.queries {
		queries = append(queries, *q)
	}
	s.queriesMutex.RUnlock()

	sort.Slice(queries, func(i, j int) bool {
		return queries[i].Timestamp.After(queries[j].Timestamp)
	})

	if limit < 0 {
		limit = 0
	}
	if len(queries) > limit {
		queries = queries[:limit]
	}
	return queries
}