Skip to content

Commit

Permalink
Fix device code verification URI unmarshalling (#556)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Feb 27, 2025
1 parent 51fa822 commit 2fbdd87
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 1 deletion.
2 changes: 1 addition & 1 deletion apps/internal/oauth/ops/accesstokens/accesstokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ type DeviceCodeResponse struct {

UserCode string `json:"user_code"`
DeviceCode string `json:"device_code"`
VerificationURL string `json:"verification_url"`
VerificationURL string `json:"verification_uri"`
ExpiresIn int `json:"expires_in"`
Interval int `json:"interval"`
Message string `json:"message"`
Expand Down
79 changes: 79 additions & 0 deletions apps/public/public_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,85 @@ func TestAcquireTokenSilentWithoutAccount(t *testing.T) {
}
}

func TestAcquireTokenByDeviceCode(t *testing.T) {
accessToken := "*"
expected := accesstokens.DeviceCodeResult{
ClientID: "client-id",
DeviceCode: "device-code",
Message: "msg",
Interval: 1,
Scopes: tokenScope,
UserCode: "user-code",
VerificationURL: "https://device.code.local",
}
mockClient := mock.Client{}
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody("http://localhost", "tenant")))
mockClient.AppendResponse(mock.WithBody([]byte(
fmt.Sprintf(
`{"device_code":%q,"expires_in":60,"interval":%d,"message":%q,"user_code":%q,"verification_uri":%q}`,
expected.DeviceCode,
expected.Interval,
expected.Message,
expected.UserCode,
expected.VerificationURL,
),
)))
mockClient.AppendResponse(
mock.WithBody(mock.GetAccessTokenBody(accessToken, "", "rt", "", 3600)),
mock.WithCallback(func(r *http.Request) {
if r.Method != http.MethodPost {
t.Fatalf("unexpected method %q", r.Method)
}
if err := r.ParseForm(); err != nil {
t.Fatal(err)
}
if v := r.Form.Get("client_id"); v != expected.ClientID {
t.Fatalf("unexpected client_id %q", v)
}
if v := r.Form.Get("device_code"); v != expected.DeviceCode {
t.Fatalf("unexpected device_code %q", v)
}
}),
)
client, err := New(expected.ClientID, WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}
dc, err := client.AcquireTokenByDeviceCode(context.Background(), tokenScope)
if err != nil {
t.Fatal(err)
}
actual := dc.Result
if actual.ClientID != expected.ClientID {
t.Fatalf("unexpected client ID %q", actual.ClientID)
}
if actual.DeviceCode != expected.DeviceCode {
t.Fatalf("unexpected device code %q", actual.DeviceCode)
}
if !actual.ExpiresOn.After(time.Now()) {
t.Fatalf("expected a future expiration time but got %v", actual.ExpiresOn)
}
if actual.Interval != expected.Interval {
t.Fatalf("unexpected interval %d", actual.Interval)
}
if actual.Message != expected.Message {
t.Fatalf("unexpected message %q", actual.Message)
}
if actual.UserCode != expected.UserCode {
t.Fatalf("unexpected user code %q", actual.UserCode)
}
if actual.VerificationURL != expected.VerificationURL {
t.Fatalf("unexpected verification URL %q", actual.VerificationURL)
}
ar, err := dc.AuthenticationResult(context.Background())
if err != nil {
t.Fatal(err)
}
if ar.AccessToken != accessToken {
t.Fatalf("unexpected access token %q", ar.AccessToken)
}
}

func TestAcquireTokenWithTenantID(t *testing.T) {
accessToken := "*"
clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`))
Expand Down

0 comments on commit 2fbdd87

Please sign in to comment.