Content-Type auto detection for client
Content-Type auto detection for client

--- a/doh-client/client.go
+++ b/doh-client/client.go
@@ -30,6 +30,7 @@
 	"net"
 	"net/http"
 	"net/http/cookiejar"
+	"strings"
 	"time"
 
 	"../json-dns"
@@ -44,6 +45,17 @@
 	tcpServer     *dns.Server
 	httpTransport *http.Transport
 	httpClient    *http.Client
+}
+
+// DNSRequest carries the state of one upstream DoH query from the
+// generateRequest* step to a parseResponse* step.
+type DNSRequest struct {
+	response          *http.Response // upstream HTTP response; nil when err is set
+	reply             *dns.Msg       // reply skeleton prepared from the client's query
+	udpSize           uint16         // size limit used to truncate non-TCP replies
+	ednsClientAddress net.IP         // EDNS Client Subnet address; may be nil
+	ednsClientNetmask uint8          // EDNS Client Subnet source netmask
+	err               error          // set when the request failed; an error reply was already written
 }
 
 func NewClient(conf *config) (c *Client, err error) {
@@ -144,20 +154,58 @@
 		return
 	}
 
+	requestType := ""
 	if len(c.conf.UpstreamIETF) == 0 {
-		c.handlerFuncGoogle(w, r, isTCP)
-		return
-	}
-	if len(c.conf.UpstreamGoogle) == 0 {
-		c.handlerFuncIETF(w, r, isTCP)
-		return
-	}
-	numServers := len(c.conf.UpstreamGoogle) + len(c.conf.UpstreamIETF)
-	random := rand.Intn(numServers)
-	if random < len(c.conf.UpstreamGoogle) {
-		c.handlerFuncGoogle(w, r, isTCP)
-	} else {
-		c.handlerFuncIETF(w, r, isTCP)
+		requestType = "application/x-www-form-urlencoded"
+	} else if len(c.conf.UpstreamGoogle) == 0 {
+		requestType = "application/dns-udpwireformat"
+	} else {
+		numServers := len(c.conf.UpstreamGoogle) + len(c.conf.UpstreamIETF)
+		random := rand.Intn(numServers)
+		if random < len(c.conf.UpstreamGoogle) {
+			requestType = "application/x-www-form-urlencoded"
+		} else {
+			requestType = "application/dns-udpwireformat"
+		}
+	}
+
+	var req *DNSRequest
+	if requestType == "application/x-www-form-urlencoded" {
+		req = c.generateRequestGoogle(w, r, isTCP)
+	} else if requestType == "application/dns-udpwireformat" {
+		req = c.generateRequestIETF(w, r, isTCP)
+	} else {
+		panic("Unknown request Content-Type")
+	}
+
+	// On failure the generate* helpers have already written an error reply
+	// to w and left req.response nil, so there is nothing left to parse.
+	if req.err != nil {
+		return
+	}
+	// Close the upstream body after parsing so the connection can be reused.
+	defer req.response.Body.Close()
+
+	contentType := ""
+	candidateType := strings.SplitN(req.response.Header.Get("Content-Type"), ";", 2)[0]
+	if candidateType == "application/json" {
+		contentType = "application/json"
+	} else if candidateType == "application/dns-udpwireformat" {
+		contentType = "application/dns-udpwireformat"
+	} else {
+		if requestType == "application/x-www-form-urlencoded" {
+			contentType = "application/json"
+		} else if requestType == "application/dns-udpwireformat" {
+			contentType = "application/dns-udpwireformat"
+		}
+	}
+
+	if contentType == "application/json" {
+		c.parseResponseGoogle(w, r, isTCP, req)
+	} else if contentType == "application/dns-udpwireformat" {
+		c.parseResponseIETF(w, r, isTCP, req)
+	} else {
+		panic("Unknown response Content-Type")
 	}
 }
 

--- a/doh-client/google.go
+++ b/doh-client/google.go
@@ -39,14 +39,16 @@
 	"github.com/miekg/dns"
 )
 
-func (c *Client) handlerFuncGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
+func (c *Client) generateRequestGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest {
 	reply := jsonDNS.PrepareReply(r)
 
 	if len(r.Question) != 1 {
 		log.Println("Number of questions is not 1")
 		reply.Rcode = dns.RcodeFormatError
 		w.WriteMsg(reply)
-		return
+		return &DNSRequest{
+			err: &dns.Error{},
+		}
 	}
 	question := &r.Question[0]
 	// knot-resolver scrambles capitalization, I think it is unfriendly to cache
@@ -85,9 +87,11 @@
 		log.Println(err)
 		reply.Rcode = dns.RcodeServerFailure
 		w.WriteMsg(reply)
-		return
+		return &DNSRequest{
+			err: err,
+		}
 	}
-	req.Header.Set("Accept", "application/json")
+	req.Header.Set("Accept", "application/json, application/dns-udpwireformat")
 	req.Header.Set("User-Agent", "DNS-over-HTTPS/1.1 (+https://github.com/m13253/dns-over-https)")
 	resp, err := c.httpClient.Do(req)
 	if err != nil {
@@ -95,23 +99,36 @@
 		reply.Rcode = dns.RcodeServerFailure
 		w.WriteMsg(reply)
 		c.httpTransport.CloseIdleConnections()
-		return
+		return &DNSRequest{
+			err: err,
+		}
 	}
-	if resp.StatusCode != 200 {
-		log.Printf("HTTP error: %s\n", resp.Status)
-		reply.Rcode = dns.RcodeServerFailure
-		contentType := resp.Header.Get("Content-Type")
+
+	return &DNSRequest{
+		response:          resp,
+		reply:             reply,
+		udpSize:           udpSize,
+		ednsClientAddress: ednsClientAddress,
+		ednsClientNetmask: ednsClientNetmask,
+	}
+}
+
+func (c *Client) parseResponseGoogle(w dns.ResponseWriter, r *dns.Msg, isTCP bool, req *DNSRequest) {
+	if req.response.StatusCode != 200 {
+		log.Printf("HTTP error: %s\n", req.response.Status)
+		req.reply.Rcode = dns.RcodeServerFailure
+		contentType := req.response.Header.Get("Content-Type")
 		if contentType != "application/json" && !strings.HasPrefix(contentType, "application/json;") {
-			w.WriteMsg(reply)
+			w.WriteMsg(req.reply)
 			return
 		}
 	}
 
-	body, err := ioutil.ReadAll(resp.Body)
+	body, err := ioutil.ReadAll(req.response.Body)
 	if err != nil {
 		log.Println(err)
-		reply.Rcode = dns.RcodeServerFailure
-		w.WriteMsg(reply)
+		req.reply.Rcode = dns.RcodeServerFailure
+		w.WriteMsg(req.reply)
 		return
 	}
 
@@ -119,8 +136,8 @@
 	err = json.Unmarshal(body, &respJSON)
 	if err != nil {
 		log.Println(err)
-		reply.Rcode = dns.RcodeServerFailure
-		w.WriteMsg(reply)
+		req.reply.Rcode = dns.RcodeServerFailure
+		w.WriteMsg(req.reply)
 		return
 	}
 
@@ -128,22 +145,22 @@
 		log.Printf("DNS error: %s\n", respJSON.Comment)
 	}
 
-	fullReply := jsonDNS.Unmarshal(reply, &respJSON, udpSize, ednsClientNetmask)
+	fullReply := jsonDNS.Unmarshal(req.reply, &respJSON, req.udpSize, req.ednsClientNetmask)
 	buf, err := fullReply.Pack()
 	if err != nil {
 		log.Println(err)
-		reply.Rcode = dns.RcodeServerFailure
-		w.WriteMsg(reply)
+		req.reply.Rcode = dns.RcodeServerFailure
+		w.WriteMsg(req.reply)
 		return
 	}
-	if !isTCP && len(buf) > int(udpSize) {
+	if !isTCP && len(buf) > int(req.udpSize) {
 		fullReply.Truncated = true
 		buf, err = fullReply.Pack()
 		if err != nil {
 			log.Println(err)
 			return
 		}
-		buf = buf[:udpSize]
+		buf = buf[:req.udpSize]
 	}
 	w.Write(buf)
 }

--- a/doh-client/ietf.go
+++ b/doh-client/ietf.go
@@ -30,6 +30,7 @@
 	"io/ioutil"
 	"log"
 	"math/rand"
+	"net"
 	"net/http"
 	"strconv"
 	"strings"
@@ -39,14 +40,16 @@
 	"github.com/miekg/dns"
 )
 
-func (c *Client) handlerFuncIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
+func (c *Client) generateRequestIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool) *DNSRequest {
 	reply := jsonDNS.PrepareReply(r)
 
 	if len(r.Question) != 1 {
 		log.Println("Number of questions is not 1")
 		reply.Rcode = dns.RcodeFormatError
 		w.WriteMsg(reply)
-		return
+		return &DNSRequest{
+			err: &dns.Error{},
+		}
 	}
 
 	question := &r.Question[0]
@@ -63,8 +66,6 @@
 		fmt.Printf("%s - - [%s] \"%s IN %s\"\n", w.RemoteAddr(), time.Now().Format("02/Jan/2006:15:04:05 -0700"), questionName, questionType)
 	}
 
-	requestID := r.Id
-	r.Id = 0
 	question.Name = questionName
 	opt := r.IsEdns0()
 	udpSize := uint16(512)
@@ -85,9 +86,10 @@
 			break
 		}
 	}
+	ednsClientAddress, ednsClientNetmask := net.IP(nil), uint8(255)
 	if edns0Subnet == nil {
 		ednsClientFamily := uint16(0)
-		ednsClientAddress, ednsClientNetmask := c.findClientIP(w, r)
+		ednsClientAddress, ednsClientNetmask = c.findClientIP(w, r)
 		if ednsClientAddress != nil {
 			if ipv4 := ednsClientAddress.To4(); ipv4 != nil {
 				ednsClientFamily = 1
@@ -105,15 +107,22 @@
 			edns0Subnet.Address = ednsClientAddress
 			opt.Option = append(opt.Option, edns0Subnet)
 		}
-	}
-
+	} else {
+		ednsClientAddress, ednsClientNetmask = edns0Subnet.Address, edns0Subnet.SourceNetmask
+	}
+
+	requestID := r.Id
+	r.Id = 0
 	requestBinary, err := r.Pack()
 	if err != nil {
 		log.Println(err)
 		reply.Rcode = dns.RcodeFormatError
 		w.WriteMsg(reply)
-		return
-	}
+		return &DNSRequest{
+			err: err,
+		}
+	}
+	r.Id = requestID
 	requestBase64 := base64.RawURLEncoding.EncodeToString(requestBinary)
 
 	numServers := len(c.conf.UpstreamIETF)
@@ -127,7 +136,9 @@
 			log.Println(err)
 			reply.Rcode = dns.RcodeServerFailure
 			w.WriteMsg(reply)
-			return
+			return &DNSRequest{
+				err: err,
+			}
 		}
 	} else {
 		req, err = http.NewRequest("POST", upstream, bytes.NewReader(requestBinary))
@@ -135,11 +146,13 @@
 			log.Println(err)
 			reply.Rcode = dns.RcodeServerFailure
 			w.WriteMsg(reply)
-			return
+			return &DNSRequest{
+				err: err,
+			}
 		}
 		req.Header.Set("Content-Type", "application/dns-udpwireformat")
 	}
-	req.Header.Set("Accept", "application/dns-udpwireformat")
+	req.Header.Set("Accept", "application/dns-udpwireformat, application/json")
 	req.Header.Set("User-Agent", "DNS-over-HTTPS/1.1 (+https://github.com/m13253/dns-over-https)")
 	resp, err := c.httpClient.Do(req)
 	if err != nil {
@@ -147,26 +160,39 @@
 		reply.Rcode = dns.RcodeServerFailure
 		w.WriteMsg(reply)
 		c.httpTransport.CloseIdleConnections()
+		return &DNSRequest{
+			err: err,
+		}
+	}
+
+	return &DNSRequest{
+		response:          resp,
+		reply:             reply,
+		udpSize:           udpSize,
+		ednsClientAddress: ednsClientAddress,
+		ednsClientNetmask: ednsClientNetmask,
+	}
+}
+
+func (c *Client) parseResponseIETF(w dns.ResponseWriter, r *dns.Msg, isTCP bool, req *DNSRequest) {
+	if req.response.StatusCode != 200 {
+		log.Printf("HTTP error: %s\n", req.response.Status)
+		req.reply.Rcode = dns.RcodeServerFailure
+		contentType := req.response.Header.Get("Content-Type")
+		if contentType != "application/dns-udpwireformat" && !strings.HasPrefix(contentType, "application/dns-udpwireformat;") {
+			w.WriteMsg(req.reply)
+			return
+		}
+	}
+
+	body, err := ioutil.ReadAll(req.response.Body)
+	if err != nil {
+		log.Println(err)
+		req.reply.Rcode = dns.RcodeServerFailure
+		w.WriteMsg(req.reply)
 		return
 	}
-	if resp.StatusCode != 200 {
-		log.Printf("HTTP error: %s\n", resp.Status)
-		reply.Rcode = dns.RcodeServerFailure
-		contentType := resp.Header.Get("Content-Type")
-		if contentType != "application/dns-udpwireformat" && !strings.HasPrefix(contentType, "application/dns-udpwireformat;") {
-			w.WriteMsg(reply)
-			return
-		}
-	}
-
-	body, err := ioutil.ReadAll(resp.Body)
-	if err != nil {
-		log.Println(err)
-		reply.Rcode = dns.RcodeServerFailure
-		w.WriteMsg(reply)
-		return
-	}
-	headerNow := resp.Header.Get("Date")
+	headerNow := req.response.Header.Get("Date")
 	now := time.Now().UTC()
 	if headerNow != "" {
 		if nowDate, err := time.Parse(http.TimeFormat, headerNow); err == nil {
@@ -175,7 +201,7 @@
 			log.Println(err)
 		}
 	}
-	headerLastModified := resp.Header.Get("Last-Modified")
+	headerLastModified := req.response.Header.Get("Last-Modified")
 	lastModified := now
 	if headerLastModified != "" {
 		if lastModifiedDate, err := time.Parse(http.TimeFormat, headerLastModified); err == nil {
@@ -193,12 +219,12 @@
 	err = fullReply.Unpack(body)
 	if err != nil {
 		log.Println(err)
-		reply.Rcode = dns.RcodeServerFailure
-		w.WriteMsg(reply)
+		req.reply.Rcode = dns.RcodeServerFailure
+		w.WriteMsg(req.reply)
 		return
 	}
 
-	fullReply.Id = requestID
+	fullReply.Id = r.Id
 	for _, rr := range fullReply.Answer {
 		_ = fixRecordTTL(rr, timeDelta)
 	}
@@ -215,18 +241,18 @@
 	buf, err := fullReply.Pack()
 	if err != nil {
 		log.Println(err)
-		reply.Rcode = dns.RcodeServerFailure
-		w.WriteMsg(reply)
+		req.reply.Rcode = dns.RcodeServerFailure
+		w.WriteMsg(req.reply)
 		return
 	}
-	if !isTCP && len(buf) > int(udpSize) {
+	if !isTCP && len(buf) > int(req.udpSize) {
 		fullReply.Truncated = true
 		buf, err = fullReply.Pack()
 		if err != nil {
 			log.Println(err)
 			return
 		}
-		buf = buf[:udpSize]
+		buf = buf[:req.udpSize]
 	}
 	w.Write(buf)
 }

--- a/doh-server/server.go
+++ b/doh-server/server.go
@@ -146,6 +146,8 @@
 		s.generateResponseGoogle(w, r, req)
 	} else if contentType == "application/dns-udpwireformat" {
 		s.generateResponseIETF(w, r, req)
+	} else {
+		panic("Unknown response Content-Type")
 	}
 }