Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow sign in with email #100

Merged
merged 11 commits into from
Jan 19, 2025
2 changes: 0 additions & 2 deletions backend/internal/bootstrap/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package bootstrap

import (
_ "github.com/golang-migrate/migrate/v4/source/file"
"github.com/stonith404/pocket-id/backend/internal/job"
"github.com/stonith404/pocket-id/backend/internal/service"
)

Expand All @@ -11,6 +10,5 @@ func Bootstrap() {
appConfigService := service.NewAppConfigService(db)

initApplicationImages()
job.RegisterJobs(db)
initRouter(db, appConfigService)
}
10 changes: 7 additions & 3 deletions backend/internal/bootstrap/router_bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,25 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
auditLogService := service.NewAuditLogService(db, appConfigService, emailService, geoLiteService)
jwtService := service.NewJwtService(appConfigService)
webauthnService := service.NewWebAuthnService(db, jwtService, auditLogService, appConfigService)
userService := service.NewUserService(db, jwtService, auditLogService)
userService := service.NewUserService(db, jwtService, auditLogService, emailService)
customClaimService := service.NewCustomClaimService(db)
oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService, customClaimService)
testService := service.NewTestService(db, appConfigService)
userGroupService := service.NewUserGroupService(db)
ldapService := service.NewLdapService(db, appConfigService, userService, userGroupService)

rateLimitMiddleware := middleware.NewRateLimitMiddleware()

// Setup global middleware
r.Use(middleware.NewCorsMiddleware().Add())
r.Use(middleware.NewErrorHandlerMiddleware().Add())
r.Use(middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60))
r.Use(rateLimitMiddleware.Add(rate.Every(time.Second), 60))
r.Use(middleware.NewJwtAuthMiddleware(jwtService, true).Add(false))

job.RegisterLdapJobs(ldapService, appConfigService)
job.RegisterDbCleanupJobs(db)

// Initialize middleware
// Initialize middleware for specific routes
jwtAuthMiddleware := middleware.NewJwtAuthMiddleware(jwtService, false)
fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware()

Expand Down
2 changes: 1 addition & 1 deletion backend/internal/common/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (e *MissingPermissionError) HttpStatusCode() int { return http.StatusForbid
type TooManyRequestsError struct{}

func (e *TooManyRequestsError) Error() string {
return "Too many requests. Please wait a while before trying again."
return "Too many requests"
}
func (e *TooManyRequestsError) HttpStatusCode() int { return http.StatusTooManyRequests }

Expand Down
21 changes: 19 additions & 2 deletions backend/internal/controller/user_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func NewUserController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.Jwt
group.POST("/users/:id/one-time-access-token", jwtAuthMiddleware.Add(true), uc.createOneTimeAccessTokenHandler)
group.POST("/one-time-access-token/:token", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), uc.exchangeOneTimeAccessTokenHandler)
group.POST("/one-time-access-token/setup", uc.getSetupAccessTokenHandler)
group.POST("/one-time-access-email", rateLimitMiddleware.Add(rate.Every(10*time.Minute), 3), uc.requestOneTimeAccessEmailHandler)
}

type UserController struct {
Expand Down Expand Up @@ -145,7 +146,7 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) {
return
}

token, err := uc.userService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt, c.ClientIP(), c.Request.UserAgent())
token, err := uc.userService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt)
if err != nil {
c.Error(err)
return
Expand All @@ -154,8 +155,24 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) {
c.JSON(http.StatusCreated, gin.H{"token": token})
}

func (uc *UserController) requestOneTimeAccessEmailHandler(c *gin.Context) {
var input dto.OneTimeAccessEmailDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
return
}

err := uc.userService.RequestOneTimeAccessEmail(input.Email, input.RedirectPath)
if err != nil {
c.Error(err)
return
}

c.Status(http.StatusNoContent)
}

func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Param("token"))
user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Param("token"), c.ClientIP(), c.Request.UserAgent())
if err != nil {
c.Error(err)
return
Expand Down
3 changes: 2 additions & 1 deletion backend/internal/dto/app_config_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ type AppConfigUpdateDto struct {
SessionDuration string `json:"sessionDuration" binding:"required"`
EmailsVerified string `json:"emailsVerified" binding:"required"`
AllowOwnAccountEdit string `json:"allowOwnAccountEdit" binding:"required"`
EmailEnabled string `json:"emailEnabled" binding:"required"`
SmtHost string `json:"smtpHost"`
SmtpPort string `json:"smtpPort"`
SmtpFrom string `json:"smtpFrom" binding:"omitempty,email"`
Expand All @@ -38,4 +37,6 @@ type AppConfigUpdateDto struct {
LdapAttributeGroupUniqueIdentifier string `json:"ldapAttributeGroupUniqueIdentifier"`
LdapAttributeGroupName string `json:"ldapAttributeGroupName"`
LdapAttributeAdminGroup string `json:"ldapAttributeAdminGroup"`
EmailOneTimeAccessEnabled string `json:"emailOneTimeAccessEnabled" binding:"required"`
EmailLoginNotificationEnabled string `json:"emailLoginNotificationEnabled" binding:"required"`
}
5 changes: 5 additions & 0 deletions backend/internal/dto/user_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,8 @@ type OneTimeAccessTokenCreateDto struct {
UserID string `json:"userId" binding:"required"`
ExpiresAt time.Time `json:"expiresAt" binding:"required"`
}

type OneTimeAccessEmailDto struct {
Email string `json:"email" binding:"required,email"`
RedirectPath string `json:"redirectPath"`
}
2 changes: 1 addition & 1 deletion backend/internal/job/db_cleanup.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"time"
)

func RegisterJobs(db *gorm.DB) {
func RegisterDbCleanupJobs(db *gorm.DB) {
scheduler, err := gocron.NewScheduler()
if err != nil {
log.Fatalf("Failed to create a new scheduler: %s", err)
Expand Down
16 changes: 8 additions & 8 deletions backend/internal/middleware/rate_limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ func NewRateLimitMiddleware() *RateLimitMiddleware {
}

func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc {
// Map to store the rate limiters per IP
var clients = make(map[string]*client)
var mu sync.Mutex

// Start the cleanup routine
go cleanupClients()
go cleanupClients(&mu, clients)

return func(c *gin.Context) {
ip := c.ClientIP()
Expand All @@ -29,7 +33,7 @@ func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc {
return
}

limiter := getLimiter(ip, limit, burst)
limiter := getLimiter(ip, limit, burst, &mu, clients)
if !limiter.Allow() {
c.Error(&common.TooManyRequestsError{})
c.Abort()
Expand All @@ -45,12 +49,8 @@ type client struct {
lastSeen time.Time
}

// Map to store the rate limiters per IP
var clients = make(map[string]*client)
var mu sync.Mutex

// Cleanup routine to remove stale clients that haven't been seen for a while
func cleanupClients() {
func cleanupClients(mu *sync.Mutex, clients map[string]*client) {
for {
time.Sleep(time.Minute)
mu.Lock()
Expand All @@ -64,7 +64,7 @@ func cleanupClients() {
}

// getLimiter retrieves the rate limiter for a given IP address, creating one if it doesn't exist
func getLimiter(ip string, limit rate.Limit, burst int) *rate.Limiter {
func getLimiter(ip string, limit rate.Limit, burst int, mu *sync.Mutex, clients map[string]*client) *rate.Limiter {
mu.Lock()
defer mu.Unlock()

Expand Down
3 changes: 2 additions & 1 deletion backend/internal/model/app_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ type AppConfig struct {
LogoLightImageType AppConfigVariable
LogoDarkImageType AppConfigVariable
// Email
EmailEnabled AppConfigVariable
SmtpHost AppConfigVariable
SmtpPort AppConfigVariable
SmtpFrom AppConfigVariable
SmtpUser AppConfigVariable
SmtpPassword AppConfigVariable
SmtpTls AppConfigVariable
SmtpSkipCertVerify AppConfigVariable
EmailLoginNotificationEnabled AppConfigVariable
EmailOneTimeAccessEnabled AppConfigVariable
// LDAP
LdapEnabled AppConfigVariable
LdapUrl AppConfigVariable
Expand Down
25 changes: 19 additions & 6 deletions backend/internal/service/app_config_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,7 @@ var defaultDbConfig = model.AppConfig{
IsInternal: true,
DefaultValue: "svg",
},
// Email
EmailEnabled: model.AppConfigVariable{
Key: "emailEnabled",
Type: "bool",
DefaultValue: "false",
},
// Email
SmtpHost: model.AppConfigVariable{
Key: "smtpHost",
Type: "string",
Expand Down Expand Up @@ -109,6 +104,17 @@ var defaultDbConfig = model.AppConfig{
Type: "bool",
DefaultValue: "false",
},
EmailLoginNotificationEnabled: model.AppConfigVariable{
Key: "emailLoginNotificationEnabled",
Type: "bool",
DefaultValue: "false",
},
EmailOneTimeAccessEnabled: model.AppConfigVariable{
Key: "emailOneTimeAccessEnabled",
Type: "bool",
IsPublic: true,
DefaultValue: "false",
},
// LDAP
LdapEnabled: model.AppConfigVariable{
Key: "ldapEnabled",
Expand Down Expand Up @@ -182,6 +188,13 @@ func (s *AppConfigService) UpdateAppConfig(input dto.AppConfigUpdateDto) ([]mode
key := field.Tag.Get("json")
value := rv.FieldByName(field.Name).String()

// If the emailEnabled is set to false, disable the emailOneTimeAccessEnabled
if key == s.DbConfig.EmailOneTimeAccessEnabled.Key {
if rv.FieldByName("EmailEnabled").String() == "false" {
value = "false"
}
}

var appConfigVariable model.AppConfigVariable
if err := tx.First(&appConfigVariable, "key = ? AND is_internal = false", key).Error; err != nil {
tx.Rollback()
Expand Down
4 changes: 2 additions & 2 deletions backend/internal/service/audit_log_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID
return createdAuditLog
}

// If the user hasn't logged in from the same device before, send an email
if count <= 1 {
// If the user hasn't logged in from the same device before and email notifications are enabled, send an email
if s.appConfigService.DbConfig.EmailLoginNotificationEnabled.Value == "true" && count <= 1 {
go func() {
var user model.User
s.db.Where("id = ?", userID).First(&user)
Expand Down
23 changes: 14 additions & 9 deletions backend/internal/service/email_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package service
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/model"
Expand All @@ -16,8 +15,13 @@ import (
"net/smtp"
"net/textproto"
ttemplate "text/template"
"time"
)

var netDialer = &net.Dialer{
Timeout: 3 * time.Second,
}

type EmailService struct {
appConfigService *AppConfigService
db *gorm.DB
Expand Down Expand Up @@ -58,11 +62,6 @@ func (srv *EmailService) SendTestEmail(recipientUserId string) error {
}

func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.Template[V], tData *V) error {
// Check if SMTP settings are set
if srv.appConfigService.DbConfig.EmailEnabled.Value != "true" {
return errors.New("email not enabled")
}

data := &email.TemplateData[V]{
AppName: srv.appConfigService.DbConfig.AppName.Value,
LogoURL: common.EnvConfig.AppURL + "/api/application-configuration/logo",
Expand Down Expand Up @@ -112,11 +111,13 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
tlsConfig,
)
}
defer client.Quit()

if err != nil {
return fmt.Errorf("failed to connect to SMTP server: %w", err)
}

defer client.Close()

smtpUser := srv.appConfigService.DbConfig.SmtpUser.Value
smtpPassword := srv.appConfigService.DbConfig.SmtpPassword.Value

Expand All @@ -141,7 +142,11 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
}

func (srv *EmailService) connectToSmtpServerUsingImplicitTLS(serverAddr string, tlsConfig *tls.Config) (*smtp.Client, error) {
conn, err := tls.Dial("tcp", serverAddr, tlsConfig)
tlsDialer := &tls.Dialer{
NetDialer: netDialer,
Config: tlsConfig,
}
conn, err := tlsDialer.Dial("tcp", serverAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
}
Expand All @@ -156,7 +161,7 @@ func (srv *EmailService) connectToSmtpServerUsingImplicitTLS(serverAddr string,
}

func (srv *EmailService) connectToSmtpServerUsingStartTLS(serverAddr string, tlsConfig *tls.Config) (*smtp.Client, error) {
conn, err := net.Dial("tcp", serverAddr)
conn, err := netDialer.Dial("tcp", serverAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
}
Expand Down
15 changes: 13 additions & 2 deletions backend/internal/service/email_service_templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
/**
How to add new template:
- pick unique and descriptive template ${name} (for example "login-with-new-device")
- in backend/email-templates/ create "${name}_html.tmpl" and "${name}_text.tmpl"
- in backend/resources/email-templates/ create "${name}_html.tmpl" and "${name}_text.tmpl"
- create xxxxTemplate and xxxxTemplateData (for example NewLoginTemplate and NewLoginTemplateData)
- Path *must* be ${name}
- add xxxTemplate.Path to "emailTemplatePaths" at the end
Expand All @@ -27,6 +27,13 @@ var NewLoginTemplate = email.Template[NewLoginTemplateData]{
},
}

var OneTimeAccessTemplate = email.Template[OneTimeAccessTemplateData]{
Path: "one-time-access",
Title: func(data *email.TemplateData[OneTimeAccessTemplateData]) string {
return "One time access"
},
}

var TestTemplate = email.Template[struct{}]{
Path: "test",
Title: func(data *email.TemplateData[struct{}]) string {
Expand All @@ -42,5 +49,9 @@ type NewLoginTemplateData struct {
DateTime time.Time
}

type OneTimeAccessTemplateData = struct {
Link string
}

// this is list of all template paths used for preloading templates
var emailTemplatesPaths = []string{NewLoginTemplate.Path, TestTemplate.Path}
var emailTemplatesPaths = []string{NewLoginTemplate.Path, OneTimeAccessTemplate.Path, TestTemplate.Path}
Loading
Loading