Skip to content

Commit

Permalink
feat: allow sign in with email (#100)
Browse files Browse the repository at this point in the history
  • Loading branch information
stonith404 authored Jan 19, 2025
1 parent e284e35 commit 06b90ed
Show file tree
Hide file tree
Showing 42 changed files with 422 additions and 145 deletions.
2 changes: 0 additions & 2 deletions backend/internal/bootstrap/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package bootstrap

import (
_ "github.com/golang-migrate/migrate/v4/source/file"
"github.com/stonith404/pocket-id/backend/internal/job"
"github.com/stonith404/pocket-id/backend/internal/service"
)

Expand All @@ -11,6 +10,5 @@ func Bootstrap() {
appConfigService := service.NewAppConfigService(db)

initApplicationImages()
job.RegisterJobs(db)
initRouter(db, appConfigService)
}
10 changes: 7 additions & 3 deletions backend/internal/bootstrap/router_bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,25 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
auditLogService := service.NewAuditLogService(db, appConfigService, emailService, geoLiteService)
jwtService := service.NewJwtService(appConfigService)
webauthnService := service.NewWebAuthnService(db, jwtService, auditLogService, appConfigService)
userService := service.NewUserService(db, jwtService, auditLogService)
userService := service.NewUserService(db, jwtService, auditLogService, emailService)
customClaimService := service.NewCustomClaimService(db)
oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService, customClaimService)
testService := service.NewTestService(db, appConfigService)
userGroupService := service.NewUserGroupService(db)
ldapService := service.NewLdapService(db, appConfigService, userService, userGroupService)

rateLimitMiddleware := middleware.NewRateLimitMiddleware()

// Setup global middleware
r.Use(middleware.NewCorsMiddleware().Add())
r.Use(middleware.NewErrorHandlerMiddleware().Add())
r.Use(middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60))
r.Use(rateLimitMiddleware.Add(rate.Every(time.Second), 60))
r.Use(middleware.NewJwtAuthMiddleware(jwtService, true).Add(false))

job.RegisterLdapJobs(ldapService, appConfigService)
job.RegisterDbCleanupJobs(db)

// Initialize middleware
// Initialize middleware for specific routes
jwtAuthMiddleware := middleware.NewJwtAuthMiddleware(jwtService, false)
fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware()

Expand Down
2 changes: 1 addition & 1 deletion backend/internal/common/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (e *MissingPermissionError) HttpStatusCode() int { return http.StatusForbid
type TooManyRequestsError struct{}

func (e *TooManyRequestsError) Error() string {
return "Too many requests. Please wait a while before trying again."
return "Too many requests"
}
func (e *TooManyRequestsError) HttpStatusCode() int { return http.StatusTooManyRequests }

Expand Down
21 changes: 19 additions & 2 deletions backend/internal/controller/user_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func NewUserController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.Jwt
group.POST("/users/:id/one-time-access-token", jwtAuthMiddleware.Add(true), uc.createOneTimeAccessTokenHandler)
group.POST("/one-time-access-token/:token", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), uc.exchangeOneTimeAccessTokenHandler)
group.POST("/one-time-access-token/setup", uc.getSetupAccessTokenHandler)
group.POST("/one-time-access-email", rateLimitMiddleware.Add(rate.Every(10*time.Minute), 3), uc.requestOneTimeAccessEmailHandler)
}

type UserController struct {
Expand Down Expand Up @@ -145,7 +146,7 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) {
return
}

token, err := uc.userService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt, c.ClientIP(), c.Request.UserAgent())
token, err := uc.userService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt)
if err != nil {
c.Error(err)
return
Expand All @@ -154,8 +155,24 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) {
c.JSON(http.StatusCreated, gin.H{"token": token})
}

func (uc *UserController) requestOneTimeAccessEmailHandler(c *gin.Context) {
var input dto.OneTimeAccessEmailDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
return
}

err := uc.userService.RequestOneTimeAccessEmail(input.Email, input.RedirectPath)
if err != nil {
c.Error(err)
return
}

c.Status(http.StatusNoContent)
}

func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Param("token"))
user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Param("token"), c.ClientIP(), c.Request.UserAgent())
if err != nil {
c.Error(err)
return
Expand Down
3 changes: 2 additions & 1 deletion backend/internal/dto/app_config_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ type AppConfigUpdateDto struct {
SessionDuration string `json:"sessionDuration" binding:"required"`
EmailsVerified string `json:"emailsVerified" binding:"required"`
AllowOwnAccountEdit string `json:"allowOwnAccountEdit" binding:"required"`
EmailEnabled string `json:"emailEnabled" binding:"required"`
SmtHost string `json:"smtpHost"`
SmtpPort string `json:"smtpPort"`
SmtpFrom string `json:"smtpFrom" binding:"omitempty,email"`
Expand All @@ -38,4 +37,6 @@ type AppConfigUpdateDto struct {
LdapAttributeGroupUniqueIdentifier string `json:"ldapAttributeGroupUniqueIdentifier"`
LdapAttributeGroupName string `json:"ldapAttributeGroupName"`
LdapAttributeAdminGroup string `json:"ldapAttributeAdminGroup"`
EmailOneTimeAccessEnabled string `json:"emailOneTimeAccessEnabled" binding:"required"`
EmailLoginNotificationEnabled string `json:"emailLoginNotificationEnabled" binding:"required"`
}
5 changes: 5 additions & 0 deletions backend/internal/dto/user_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,8 @@ type OneTimeAccessTokenCreateDto struct {
UserID string `json:"userId" binding:"required"`
ExpiresAt time.Time `json:"expiresAt" binding:"required"`
}

type OneTimeAccessEmailDto struct {
Email string `json:"email" binding:"required,email"`
RedirectPath string `json:"redirectPath"`
}
2 changes: 1 addition & 1 deletion backend/internal/job/db_cleanup.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"time"
)

func RegisterJobs(db *gorm.DB) {
func RegisterDbCleanupJobs(db *gorm.DB) {
scheduler, err := gocron.NewScheduler()
if err != nil {
log.Fatalf("Failed to create a new scheduler: %s", err)
Expand Down
16 changes: 8 additions & 8 deletions backend/internal/middleware/rate_limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ func NewRateLimitMiddleware() *RateLimitMiddleware {
}

func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc {
// Map to store the rate limiters per IP
var clients = make(map[string]*client)
var mu sync.Mutex

// Start the cleanup routine
go cleanupClients()
go cleanupClients(&mu, clients)

return func(c *gin.Context) {
ip := c.ClientIP()
Expand All @@ -29,7 +33,7 @@ func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc {
return
}

limiter := getLimiter(ip, limit, burst)
limiter := getLimiter(ip, limit, burst, &mu, clients)
if !limiter.Allow() {
c.Error(&common.TooManyRequestsError{})
c.Abort()
Expand All @@ -45,12 +49,8 @@ type client struct {
lastSeen time.Time
}

// Map to store the rate limiters per IP
var clients = make(map[string]*client)
var mu sync.Mutex

// Cleanup routine to remove stale clients that haven't been seen for a while
func cleanupClients() {
func cleanupClients(mu *sync.Mutex, clients map[string]*client) {
for {
time.Sleep(time.Minute)
mu.Lock()
Expand All @@ -64,7 +64,7 @@ func cleanupClients() {
}

// getLimiter retrieves the rate limiter for a given IP address, creating one if it doesn't exist
func getLimiter(ip string, limit rate.Limit, burst int) *rate.Limiter {
func getLimiter(ip string, limit rate.Limit, burst int, mu *sync.Mutex, clients map[string]*client) *rate.Limiter {
mu.Lock()
defer mu.Unlock()

Expand Down
3 changes: 2 additions & 1 deletion backend/internal/model/app_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ type AppConfig struct {
LogoLightImageType AppConfigVariable
LogoDarkImageType AppConfigVariable
// Email
EmailEnabled AppConfigVariable
SmtpHost AppConfigVariable
SmtpPort AppConfigVariable
SmtpFrom AppConfigVariable
SmtpUser AppConfigVariable
SmtpPassword AppConfigVariable
SmtpTls AppConfigVariable
SmtpSkipCertVerify AppConfigVariable
EmailLoginNotificationEnabled AppConfigVariable
EmailOneTimeAccessEnabled AppConfigVariable
// LDAP
LdapEnabled AppConfigVariable
LdapUrl AppConfigVariable
Expand Down
25 changes: 19 additions & 6 deletions backend/internal/service/app_config_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,7 @@ var defaultDbConfig = model.AppConfig{
IsInternal: true,
DefaultValue: "svg",
},
// Email
EmailEnabled: model.AppConfigVariable{
Key: "emailEnabled",
Type: "bool",
DefaultValue: "false",
},
// Email
SmtpHost: model.AppConfigVariable{
Key: "smtpHost",
Type: "string",
Expand Down Expand Up @@ -109,6 +104,17 @@ var defaultDbConfig = model.AppConfig{
Type: "bool",
DefaultValue: "false",
},
EmailLoginNotificationEnabled: model.AppConfigVariable{
Key: "emailLoginNotificationEnabled",
Type: "bool",
DefaultValue: "false",
},
EmailOneTimeAccessEnabled: model.AppConfigVariable{
Key: "emailOneTimeAccessEnabled",
Type: "bool",
IsPublic: true,
DefaultValue: "false",
},
// LDAP
LdapEnabled: model.AppConfigVariable{
Key: "ldapEnabled",
Expand Down Expand Up @@ -182,6 +188,13 @@ func (s *AppConfigService) UpdateAppConfig(input dto.AppConfigUpdateDto) ([]mode
key := field.Tag.Get("json")
value := rv.FieldByName(field.Name).String()

// If the emailEnabled is set to false, disable the emailOneTimeAccessEnabled
if key == s.DbConfig.EmailOneTimeAccessEnabled.Key {
if rv.FieldByName("EmailEnabled").String() == "false" {
value = "false"
}
}

var appConfigVariable model.AppConfigVariable
if err := tx.First(&appConfigVariable, "key = ? AND is_internal = false", key).Error; err != nil {
tx.Rollback()
Expand Down
4 changes: 2 additions & 2 deletions backend/internal/service/audit_log_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID
return createdAuditLog
}

// If the user hasn't logged in from the same device before, send an email
if count <= 1 {
// If the user hasn't logged in from the same device before and email notifications are enabled, send an email
if s.appConfigService.DbConfig.EmailLoginNotificationEnabled.Value == "true" && count <= 1 {
go func() {
var user model.User
s.db.Where("id = ?", userID).First(&user)
Expand Down
23 changes: 14 additions & 9 deletions backend/internal/service/email_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package service
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/model"
Expand All @@ -16,8 +15,13 @@ import (
"net/smtp"
"net/textproto"
ttemplate "text/template"
"time"
)

var netDialer = &net.Dialer{
Timeout: 3 * time.Second,
}

type EmailService struct {
appConfigService *AppConfigService
db *gorm.DB
Expand Down Expand Up @@ -58,11 +62,6 @@ func (srv *EmailService) SendTestEmail(recipientUserId string) error {
}

func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.Template[V], tData *V) error {
// Check if SMTP settings are set
if srv.appConfigService.DbConfig.EmailEnabled.Value != "true" {
return errors.New("email not enabled")
}

data := &email.TemplateData[V]{
AppName: srv.appConfigService.DbConfig.AppName.Value,
LogoURL: common.EnvConfig.AppURL + "/api/application-configuration/logo",
Expand Down Expand Up @@ -112,11 +111,13 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
tlsConfig,
)
}
defer client.Quit()

if err != nil {
return fmt.Errorf("failed to connect to SMTP server: %w", err)
}

defer client.Close()

smtpUser := srv.appConfigService.DbConfig.SmtpUser.Value
smtpPassword := srv.appConfigService.DbConfig.SmtpPassword.Value

Expand All @@ -141,7 +142,11 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
}

func (srv *EmailService) connectToSmtpServerUsingImplicitTLS(serverAddr string, tlsConfig *tls.Config) (*smtp.Client, error) {
conn, err := tls.Dial("tcp", serverAddr, tlsConfig)
tlsDialer := &tls.Dialer{
NetDialer: netDialer,
Config: tlsConfig,
}
conn, err := tlsDialer.Dial("tcp", serverAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
}
Expand All @@ -156,7 +161,7 @@ func (srv *EmailService) connectToSmtpServerUsingImplicitTLS(serverAddr string,
}

func (srv *EmailService) connectToSmtpServerUsingStartTLS(serverAddr string, tlsConfig *tls.Config) (*smtp.Client, error) {
conn, err := net.Dial("tcp", serverAddr)
conn, err := netDialer.Dial("tcp", serverAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
}
Expand Down
15 changes: 13 additions & 2 deletions backend/internal/service/email_service_templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
/**
How to add new template:
- pick unique and descriptive template ${name} (for example "login-with-new-device")
- in backend/email-templates/ create "${name}_html.tmpl" and "${name}_text.tmpl"
- in backend/resources/email-templates/ create "${name}_html.tmpl" and "${name}_text.tmpl"
- create xxxxTemplate and xxxxTemplateData (for example NewLoginTemplate and NewLoginTemplateData)
- Path *must* be ${name}
- add xxxTemplate.Path to "emailTemplatePaths" at the end
Expand All @@ -27,6 +27,13 @@ var NewLoginTemplate = email.Template[NewLoginTemplateData]{
},
}

var OneTimeAccessTemplate = email.Template[OneTimeAccessTemplateData]{
Path: "one-time-access",
Title: func(data *email.TemplateData[OneTimeAccessTemplateData]) string {
return "One time access"
},
}

var TestTemplate = email.Template[struct{}]{
Path: "test",
Title: func(data *email.TemplateData[struct{}]) string {
Expand All @@ -42,5 +49,9 @@ type NewLoginTemplateData struct {
DateTime time.Time
}

type OneTimeAccessTemplateData = struct {
Link string
}

// this is list of all template paths used for preloading templates
var emailTemplatesPaths = []string{NewLoginTemplate.Path, TestTemplate.Path}
var emailTemplatesPaths = []string{NewLoginTemplate.Path, OneTimeAccessTemplate.Path, TestTemplate.Path}
Loading

0 comments on commit 06b90ed

Please sign in to comment.