Skip to content

Commit

Permalink
Add TLS intercepting CONNECT handler
Browse files Browse the repository at this point in the history
  • Loading branch information
dvob committed Jan 23, 2023
1 parent b7a6bf2 commit 3394122
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 2 deletions.
98 changes: 98 additions & 0 deletions connect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package main

import (
"crypto/tls"
"log"
"net"
"net/http"
)

type getCertFn func(hostname string) (*tls.Config, error)

type interceptHandler struct {
listener channelListener
server *http.Server
getCert getCertFn
}

func newInterceptHandler(getCert getCertFn, innerHandler http.HandlerFunc) *interceptHandler {
server := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.URL.Scheme = "https"
r.URL.Host = r.Host
innerHandler(w, r)
}),
}

listener := channelListener(make(chan net.Conn))

go func() {
// returns always a non-nil error if the server is not closed/shtudown
_ = server.Serve(listener)
}()

return &interceptHandler{
listener: listener,
server: server,
getCert: getCert,
}
}

func (i *interceptHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
log.Printf("split host port failed '%s': %s", r.Host, err)
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}

tlsConfig, err := i.getCert(host)
if err != nil {
log.Println("failed to obtain tls config:", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}

hj, ok := w.(http.Hijacker)
if !ok {
log.Print("hijack of connection failed")
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}

w.WriteHeader(http.StatusOK)

clientConn, _, err := hj.Hijack()
if err != nil {
log.Println("hijack failed:", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}

tlsConn := tls.Server(clientConn, tlsConfig)
i.handleConnection(tlsConn)
}

func (i *interceptHandler) handleConnection(c net.Conn) {
i.listener <- c
}

func (i *interceptHandler) close() {
i.server.Close()
}

// channelListener allows to send connection into a listener through a channel
type channelListener chan net.Conn

func (cl channelListener) Accept() (net.Conn, error) {
return <-cl, nil
}

func (cl channelListener) Addr() net.Addr {
return nil
}

func (cl channelListener) Close() error {
close(cl)
return nil
}
16 changes: 14 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,27 @@ func main() {
return
}

certGen, err := newCertGenerator(caCertFile, caKeyFile)
if err != nil {
log.Print(err)
os.Exit(1)
}

connectHandler := newInterceptHandler(certGen.Get, logRequest(forward))
if err != nil {
log.Print(err)
os.Exit(1)
}

handler := logRequest(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "CONNECT" {
tunnel(w, r)
connectHandler.ServeHTTP(w, r)
} else {
forward(w, r)
}
})

err := http.ListenAndServe(":8080", http.HandlerFunc(handler))
err = http.ListenAndServe(":8080", http.HandlerFunc(handler))
if err != nil {
log.Print(err)
os.Exit(1)
Expand Down

0 comments on commit 3394122

Please sign in to comment.