Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Check origin header against allowlist for /zta endpoint" #2126

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 0 additions & 36 deletions ee/localserver/zta.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,12 @@ package localserver
import (
"log/slog"
"net/http"
"strings"

"github.com/kolide/launcher/pkg/traces"
)

var (
localserverZtaInfoKey = []byte("localserver_zta_info")

// allowlistedZtaOriginsLookup contains the complete list of origins that are permitted to access the /zta endpoint.
allowlistedZtaOriginsLookup = map[string]struct{}{
// Release extension
"chrome-extension://gejiddohjgogedgjnonbofjigllpkmbf": {},
"chrome-extension://khgocmkkpikpnmmkgmdnfckapcdkgfaf": {},
"chrome-extension://aeblfdkhhhdcdjpifhhbdiojplfjncoa": {},
"chrome-extension://dppgmdbiimibapkepcbdbmkaabgiofem": {},
"moz-extension://dfbae458-fb6f-4614-856e-094108a80852": {},
"moz-extension://25fc87fa-4d31-4fee-b5c1-c32a7844c063": {},
"moz-extension://d634138d-c276-4fc8-924b-40a0ea21d284": {},
// Development and internal builds
"chrome-extension://hjlinigoblmkhjejkmbegnoaljkphmgo": {},
"moz-extension://0a75d802-9aed-41e7-8daa-24c067386e82": {},
"chrome-extension://hiajhnnfoihkhlmfejoljaokdpgboiea": {},
"chrome-extension://kioanpobaefjdloichnjebbdafiloboa": {},
}
)

const (
safariWebExtensionScheme = "safari-web-extension://"
)

func (ls *localServer) requestZtaInfoHandler() http.Handler {
Expand All @@ -46,20 +24,6 @@ func (ls *localServer) requestZtaInfoHandlerFunc(w http.ResponseWriter, r *http.
return
}

// Validate origin. We cannot validate Safari extension origins against an allowlist
// because the UUID is generated randomly at extension startup, so for now, we also
// allow origins with scheme safari-web-extension.
requestOrigin := r.Header.Get("Origin")
if _, ok := allowlistedZtaOriginsLookup[requestOrigin]; !ok && !strings.HasPrefix(requestOrigin, safariWebExtensionScheme) {
escapedOrigin := strings.ReplaceAll(strings.ReplaceAll(requestOrigin, "\n", ""), "\r", "") // remove any newlines
ls.slogger.Log(r.Context(), slog.LevelInfo,
"received zta request with origin not in allowlist",
"req_origin", escapedOrigin,
)
w.WriteHeader(http.StatusForbidden)
return
}

ztaInfo, err := ls.knapsack.ZtaInfoStore().Get(localserverZtaInfoKey)
if err != nil {
traces.SetError(span, err)
Expand Down
109 changes: 17 additions & 92 deletions ee/localserver/zta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package localserver
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -41,59 +40,6 @@ func Test_requestZtaInfoHandler(t *testing.T) {

// Make a request to our handler
request := httptest.NewRequest(http.MethodGet, "/zta", nil)
request.Header.Set("origin", acceptableOrigin(t))
responseRecorder := httptest.NewRecorder()
ls.requestZtaInfoHandler().ServeHTTP(responseRecorder, request)

// Make sure response was successful and contains the data we expect
require.Equal(t, http.StatusOK, responseRecorder.Code)
require.Equal(t, "application/json", responseRecorder.Header().Get("Content-Type"))
require.Equal(t, testZtaInfo, responseRecorder.Body.Bytes())

k.AssertExpectations(t)
}

func acceptableOrigin(t *testing.T) string {
// Just grab the first origin available in our allowlist
acceptableOrigin := ""
for k := range allowlistedZtaOriginsLookup {
acceptableOrigin = k
break
}
if acceptableOrigin == "" {
t.Error("no acceptable origins found")
t.FailNow()
}

return acceptableOrigin
}

func Test_requestZtaInfoHandler_allowsAllSafariWebExtensionOrigins(t *testing.T) {
t.Parallel()

// Set up our ZTA store with some test data in it
slogger := multislogger.NewNopLogger()
ztaInfoStore, err := storageci.NewStore(t, slogger, storage.ZtaInfoStore.String())
require.NoError(t, err)
testZtaInfo, err := json.Marshal(map[string]string{
"some_test_data": "some_test_value",
})
require.NoError(t, err)
require.NoError(t, ztaInfoStore.Set(localserverZtaInfoKey, testZtaInfo))

// Set up the rest of our localserver dependencies
k := typesmocks.NewKnapsack(t)
k.On("KolideServerURL").Return("localserver")
k.On("Slogger").Return(slogger)
k.On("ZtaInfoStore").Return(ztaInfoStore)

// Set up localserver
ls, err := New(context.TODO(), k, nil)
require.NoError(t, err)

// Make a request to our handler
request := httptest.NewRequest(http.MethodGet, "/zta", nil)
request.Header.Set("origin", fmt.Sprintf("%sexample.com", safariWebExtensionScheme))
responseRecorder := httptest.NewRecorder()
ls.requestZtaInfoHandler().ServeHTTP(responseRecorder, request)

Expand All @@ -108,49 +54,30 @@ func Test_requestZtaInfoHandler_allowsAllSafariWebExtensionOrigins(t *testing.T)
func Test_requestZtaInfoHandler_badRequest(t *testing.T) {
t.Parallel()

reqOrigin := acceptableOrigin(t)

for _, tt := range []struct {
testCaseName string
httpMethod string
requestOrigin string
requestBody io.Reader
expectedResponseStatus int
testCaseName string
httpMethod string
requestBody io.Reader
}{
{
testCaseName: http.MethodPost,
httpMethod: http.MethodPost,
requestOrigin: reqOrigin,
requestBody: http.NoBody,
expectedResponseStatus: http.StatusMethodNotAllowed,
},
{
testCaseName: http.MethodPut,
httpMethod: http.MethodPut,
requestOrigin: reqOrigin,
requestBody: http.NoBody,
expectedResponseStatus: http.StatusMethodNotAllowed,
testCaseName: http.MethodPost,
httpMethod: http.MethodPost,
requestBody: http.NoBody,
},
{
testCaseName: http.MethodPatch,
httpMethod: http.MethodPatch,
requestOrigin: reqOrigin,
requestBody: http.NoBody,
expectedResponseStatus: http.StatusMethodNotAllowed,
testCaseName: http.MethodPut,
httpMethod: http.MethodPut,
requestBody: http.NoBody,
},
{
testCaseName: http.MethodDelete,
httpMethod: http.MethodDelete,
requestOrigin: reqOrigin,
requestBody: http.NoBody,
expectedResponseStatus: http.StatusMethodNotAllowed,
testCaseName: http.MethodPatch,
httpMethod: http.MethodPatch,
requestBody: http.NoBody,
},
{
testCaseName: "disallowed origin",
httpMethod: http.MethodGet,
requestOrigin: "https://example.com",
requestBody: http.NoBody,
expectedResponseStatus: http.StatusForbidden,
testCaseName: http.MethodDelete,
httpMethod: http.MethodDelete,
requestBody: http.NoBody,
},
} {
tt := tt
Expand All @@ -169,12 +96,11 @@ func Test_requestZtaInfoHandler_badRequest(t *testing.T) {

// Make a request to our handler
request := httptest.NewRequest(tt.httpMethod, "/zta", tt.requestBody)
request.Header.Set("origin", tt.requestOrigin)
responseRecorder := httptest.NewRecorder()
ls.requestZtaInfoHandler().ServeHTTP(responseRecorder, request)

// Make sure we got back an expected response status code (4xx-level)
require.Equal(t, tt.expectedResponseStatus, responseRecorder.Code)
// Make sure we got back a 405
require.Equal(t, http.StatusMethodNotAllowed, responseRecorder.Code)

k.AssertExpectations(t)
})
Expand All @@ -201,7 +127,6 @@ func Test_requestZtaInfoHandler_noDataAvailable(t *testing.T) {

// Make a request to our handler
request := httptest.NewRequest(http.MethodGet, "/zta", nil)
request.Header.Set("origin", acceptableOrigin(t))
responseRecorder := httptest.NewRecorder()
ls.requestZtaInfoHandler().ServeHTTP(responseRecorder, request)

Expand Down
Loading