package dns import ( "bytes" "crypto/tls" "encoding/binary" "fmt" "io" "net" "net/http" "sync" "time" ) // DNSQuery DNS查询记录 type DNSQuery struct { Domain string QueryType string Timestamp time.Time Count int } type DNSServer struct { Port int UpstreamDNS string UseDoH bool DoHURL string queries map[string]*DNSQuery queriesMutex sync.RWMutex conn *net.UDPConn httpClient *http.Client } // NewDNSServer 创建DNS服务器 func NewDNSServer(port int, upstreamDNS string) *DNSServer { return &DNSServer{ Port: port, UpstreamDNS: upstreamDNS, UseDoH: false, queries: make(map[string]*DNSQuery), httpClient: &http.Client{ Timeout: 5 * time.Second, Transport: &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: false, }, }, }, } } // NewDNSServerWithDoH 创建使用DoH的DNS服务器 func NewDNSServerWithDoH(port int, dohURL string) *DNSServer { return &DNSServer{ Port: port, UseDoH: true, DoHURL: dohURL, queries: make(map[string]*DNSQuery), httpClient: &http.Client{ Timeout: 5 * time.Second, Transport: &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: false, }, }, }, } } // Start 启动DNS服务器 func (s *DNSServer) Start() error { addr := net.UDPAddr{ Port: s.Port, IP: net.ParseIP("0.0.0.0"), } conn, err := net.ListenUDP("udp", &addr) if err != nil { return fmt.Errorf("failed to start DNS server: %v", err) } s.conn = conn fmt.Printf("DNS Server started on 0.0.0.0:%d\n", s.Port) if s.UseDoH { fmt.Printf("Using DoH: %s\n", s.DoHURL) } else { fmt.Printf("Upstream DNS: %s\n", s.UpstreamDNS) } go s.handleRequests() return nil } // Stop 停止DNS服务器 func (s *DNSServer) Stop() { if s.conn != nil { s.conn.Close() } } // handleRequests 处理DNS请求 func (s *DNSServer) handleRequests() { buffer := make([]byte, 512) for { n, clientAddr, err := s.conn.ReadFromUDP(buffer) if err != nil { continue } // 解析DNS查询 domain, queryType := s.parseDNSQuery(buffer[:n]) if domain != "" { s.recordQuery(domain, queryType) } // 转发到上游DNS go s.forwardQuery(buffer[:n], clientAddr) } } // forwardQuery 转发DNS查询到上游服务器 func (s *DNSServer) forwardQuery(query []byte, clientAddr *net.UDPAddr) { var response []byte var err error if s.UseDoH { response, err = s.forwardQueryDoH(query) } else { response, err = s.forwardQueryUDP(query) } if err != nil { return } // 返回给客户端 s.conn.WriteToUDP(response, clientAddr) } // forwardQueryUDP 通过传统UDP转发DNS查询 func (s *DNSServer) forwardQueryUDP(query []byte) ([]byte, error) { // 连接上游DNS upstreamAddr, err := net.ResolveUDPAddr("udp", s.UpstreamDNS) if err != nil { return nil, err } upstreamConn, err := net.DialUDP("udp", nil, upstreamAddr) if err != nil { return nil, err } defer upstreamConn.Close() // 设置超时 upstreamConn.SetDeadline(time.Now().Add(5 * time.Second)) // 发送查询 _, err = upstreamConn.Write(query) if err != nil { return nil, err } // 接收响应 response := make([]byte, 512) n, err := upstreamConn.Read(response) if err != nil { return nil, err } return response[:n], nil } // forwardQueryDoH 通过DoH转发DNS查询 func (s *DNSServer) forwardQueryDoH(query []byte) ([]byte, error) { // 创建HTTP POST请求 req, err := http.NewRequest("POST", s.DoHURL, bytes.NewReader(query)) if err != nil { return nil, err } // 设置DoH请求头 req.Header.Set("Content-Type", "application/dns-message") req.Header.Set("Accept", "application/dns-message") // 发送请求 resp, err := s.httpClient.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("DoH server returned status %d", resp.StatusCode) } // 读取响应 response, err := io.ReadAll(resp.Body) if err != nil { return nil, err } return response, nil } // parseDNSQuery 解析DNS查询包 func (s *DNSServer) parseDNSQuery(data []byte) (string, string) { if len(data) < 12 { return "", "" } // DNS header is 12 bytes // Questions start at byte 12 offset := 12 domain := "" for offset < len(data) { length := int(data[offset]) if length == 0 { break } if length > 63 { // Pointer, stop parsing break } offset++ if offset+length > len(data) { break } if domain != "" { domain += "." } domain += string(data[offset : offset+length]) offset += length } // Parse query type offset++ // Skip null terminator queryType := "A" if offset+2 <= len(data) { qtype := binary.BigEndian.Uint16(data[offset : offset+2]) switch qtype { case 1: queryType = "A" case 28: queryType = "AAAA" case 5: queryType = "CNAME" case 15: queryType = "MX" case 16: queryType = "TXT" default: queryType = fmt.Sprintf("TYPE%d", qtype) } } return domain, queryType } // recordQuery 记录DNS查询 func (s *DNSServer) recordQuery(domain, queryType string) { s.queriesMutex.Lock() defer s.queriesMutex.Unlock() key := domain if q, exists := s.queries[key]; exists { q.Count++ q.Timestamp = time.Now() } else { s.queries[key] = &DNSQuery{ Domain: domain, QueryType: queryType, Timestamp: time.Now(), Count: 1, } } // 限制缓存大小 if len(s.queries) > 200 { // 删除最旧的条目 oldestKey := "" oldestTime := time.Now() for key, q := range s.queries { if q.Timestamp.Before(oldestTime) { oldestTime = q.Timestamp oldestKey = key } } if oldestKey != "" { delete(s.queries, oldestKey) } } } // GetRecentQueries 获取最近的DNS查询 func (s *DNSServer) GetRecentQueries(limit int) []DNSQuery { s.queriesMutex.RLock() defer s.queriesMutex.RUnlock() queries := make([]DNSQuery, 0, len(s.queries)) for _, q := range s.queries { queries = append(queries, *q) } // 按时间排序(最新的在前) for i := 0; i < len(queries)-1; i++ { for j := i + 1; j < len(queries); j++ { if queries[j].Timestamp.After(queries[i].Timestamp) { queries[i], queries[j] = queries[j], queries[i] } } } if len(queries) > limit { queries = queries[:limit] } return queries }