Skip to content

Commit

Permalink
Fix interface conversion error: getting nil value from ThreadLocal[T]…
Browse files Browse the repository at this point in the history
…, where T is interface type
  • Loading branch information
timandy committed Dec 5, 2023
1 parent 3f1cce0 commit 37a5553
Show file tree
Hide file tree
Showing 7 changed files with 415 additions and 31 deletions.
200 changes: 200 additions & 0 deletions api_thread_local_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,52 @@ func TestNewThreadLocal_Concurrency(t *testing.T) {
task.Get()
}

func TestNewThreadLocal_Interface(t *testing.T) {
tls := NewThreadLocal[Cloneable]()
tls2 := NewThreadLocal[Cloneable]()
//
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
//
tls.Set(nil)
tls2.Set(nil)
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
//
tls.Set(&personCloneable{Id: 1, Name: "Hello"})
tls2.Set(&personCloneable{Id: 1, Name: "Hello"})
assert.NotNil(t, tls.Get())
assert.NotNil(t, tls2.Get())
//
tls.Remove()
tls2.Remove()
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
}

func TestNewThreadLocal_Pointer(t *testing.T) {
tls := NewThreadLocal[*personCloneable]()
tls2 := NewThreadLocal[*personCloneable]()
//
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
//
tls.Set(nil)
tls2.Set(nil)
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
//
tls.Set(&personCloneable{Id: 1, Name: "Hello"})
tls2.Set(&personCloneable{Id: 1, Name: "Hello"})
assert.NotNil(t, tls.Get())
assert.NotNil(t, tls2.Get())
//
tls.Remove()
tls2.Remove()
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
}

//===

func TestNewThreadLocalWithInitial_Single(t *testing.T) {
Expand Down Expand Up @@ -189,6 +235,60 @@ func TestNewThreadLocalWithInitial_Concurrency(t *testing.T) {
task.Get()
}

func TestNewThreadLocalWithInitial_Interface(t *testing.T) {
tls := NewThreadLocalWithInitial[Cloneable](func() Cloneable {
return nil
})
tls2 := NewThreadLocalWithInitial[Cloneable](func() Cloneable {
return nil
})
//
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
//
tls.Set(nil)
tls2.Set(nil)
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
//
tls.Set(&personCloneable{Id: 1, Name: "Hello"})
tls2.Set(&personCloneable{Id: 1, Name: "Hello"})
assert.NotNil(t, tls.Get())
assert.NotNil(t, tls2.Get())
//
tls.Remove()
tls2.Remove()
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
}

func TestNewThreadLocalWithInitial_Pointer(t *testing.T) {
tls := NewThreadLocalWithInitial[*personCloneable](func() *personCloneable {
return nil
})
tls2 := NewThreadLocalWithInitial[*personCloneable](func() *personCloneable {
return nil
})
//
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
//
tls.Set(nil)
tls2.Set(nil)
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
//
tls.Set(&personCloneable{Id: 1, Name: "Hello"})
tls2.Set(&personCloneable{Id: 1, Name: "Hello"})
assert.NotNil(t, tls.Get())
assert.NotNil(t, tls2.Get())
//
tls.Remove()
tls2.Remove()
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
}

//===

func TestNewInheritableThreadLocal_Single(t *testing.T) {
Expand Down Expand Up @@ -267,6 +367,52 @@ func TestNewInheritableThreadLocal_Concurrency(t *testing.T) {
task.Get()
}

func TestNewInheritableThreadLocal_Interface(t *testing.T) {
tls := NewInheritableThreadLocal[Cloneable]()
tls2 := NewInheritableThreadLocal[Cloneable]()
//
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
//
tls.Set(nil)
tls2.Set(nil)
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
//
tls.Set(&personCloneable{Id: 1, Name: "Hello"})
tls2.Set(&personCloneable{Id: 1, Name: "Hello"})
assert.NotNil(t, tls.Get())
assert.NotNil(t, tls2.Get())
//
tls.Remove()
tls2.Remove()
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
}

func TestNewInheritableThreadLocal_Pointer(t *testing.T) {
tls := NewInheritableThreadLocal[*personCloneable]()
tls2 := NewInheritableThreadLocal[*personCloneable]()
//
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
//
tls.Set(nil)
tls2.Set(nil)
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
//
tls.Set(&personCloneable{Id: 1, Name: "Hello"})
tls2.Set(&personCloneable{Id: 1, Name: "Hello"})
assert.NotNil(t, tls.Get())
assert.NotNil(t, tls2.Get())
//
tls.Remove()
tls2.Remove()
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
}

//===

func TestNewInheritableThreadLocalWithInitial_Single(t *testing.T) {
Expand Down Expand Up @@ -355,6 +501,60 @@ func TestNewInheritableThreadLocalWithInitial_Concurrency(t *testing.T) {
task.Get()
}

func TestNewInheritableThreadLocalWithInitial_Interface(t *testing.T) {
tls := NewInheritableThreadLocalWithInitial[Cloneable](func() Cloneable {
return nil
})
tls2 := NewInheritableThreadLocalWithInitial[Cloneable](func() Cloneable {
return nil
})
//
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
//
tls.Set(nil)
tls2.Set(nil)
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
//
tls.Set(&personCloneable{Id: 1, Name: "Hello"})
tls2.Set(&personCloneable{Id: 1, Name: "Hello"})
assert.NotNil(t, tls.Get())
assert.NotNil(t, tls2.Get())
//
tls.Remove()
tls2.Remove()
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
}

func TestNewInheritableThreadLocalWithInitial_Pointer(t *testing.T) {
tls := NewInheritableThreadLocalWithInitial[*personCloneable](func() *personCloneable {
return nil
})
tls2 := NewInheritableThreadLocalWithInitial[*personCloneable](func() *personCloneable {
return nil
})
//
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
//
tls.Set(nil)
tls2.Set(nil)
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
//
tls.Set(&personCloneable{Id: 1, Name: "Hello"})
tls2.Set(&personCloneable{Id: 1, Name: "Hello"})
assert.NotNil(t, tls.Get())
assert.NotNil(t, tls2.Get())
//
tls.Remove()
tls2.Remove()
assert.Nil(t, tls.Get())
assert.Nil(t, tls2.Get())
}

//===

// BenchmarkThreadLocal-8 13636471 94.17 ns/op 7 B/op 0 allocs/op
Expand Down
8 changes: 4 additions & 4 deletions thread_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (tls *threadLocal[T]) Get() T {
if mp != nil {
v := mp.get(tls.index)
if v != unset {
return v.(T)
return entryValue[T](v)
}
}
return tls.setInitialValue(t)
Expand All @@ -34,7 +34,7 @@ func (tls *threadLocal[T]) Set(value T) {
t := currentThread(true)
mp := tls.getMap(t)
if mp != nil {
mp.set(tls.index, value)
mp.set(tls.index, entry(value))
} else {
tls.createMap(t, value)
}
Expand All @@ -57,15 +57,15 @@ func (tls *threadLocal[T]) getMap(t *thread) *threadLocalMap {

func (tls *threadLocal[T]) createMap(t *thread, firstValue T) {
mp := &threadLocalMap{}
mp.set(tls.index, firstValue)
mp.set(tls.index, entry(firstValue))
t.threadLocals = mp
}

func (tls *threadLocal[T]) setInitialValue(t *thread) T {
value := tls.initialValue()
mp := tls.getMap(t)
if mp != nil {
mp.set(tls.index, value)
mp.set(tls.index, entry(value))
} else {
tls.createMap(t, value)
}
Expand Down
8 changes: 4 additions & 4 deletions thread_local_inheritable.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (tls *inheritableThreadLocal[T]) Get() T {
if mp != nil {
v := mp.get(tls.index)
if v != unset {
return v.(T)
return entryValue[T](v)
}
}
return tls.setInitialValue(t)
Expand All @@ -34,7 +34,7 @@ func (tls *inheritableThreadLocal[T]) Set(value T) {
t := currentThread(true)
mp := tls.getMap(t)
if mp != nil {
mp.set(tls.index, value)
mp.set(tls.index, entry(value))
} else {
tls.createMap(t, value)
}
Expand All @@ -57,15 +57,15 @@ func (tls *inheritableThreadLocal[T]) getMap(t *thread) *threadLocalMap {

func (tls *inheritableThreadLocal[T]) createMap(t *thread, firstValue T) {
mp := &threadLocalMap{}
mp.set(tls.index, firstValue)
mp.set(tls.index, entry(firstValue))
t.inheritableThreadLocals = mp
}

func (tls *inheritableThreadLocal[T]) setInitialValue(t *thread) T {
value := tls.initialValue()
mp := tls.getMap(t)
if mp != nil {
mp.set(tls.index, value)
mp.set(tls.index, entry(value))
} else {
tls.createMap(t, value)
}
Expand Down
18 changes: 9 additions & 9 deletions thread_local_map.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
package routine

var unset any = &object{}
var unset entry = &object{}

type object struct {
none bool //nolint:unused
}

type threadLocalMap struct {
table []any
table []entry
}

func (mp *threadLocalMap) get(index int) any {
func (mp *threadLocalMap) get(index int) entry {
lookup := mp.table
if index < len(lookup) {
return lookup[index]
}
return unset
}

func (mp *threadLocalMap) set(index int, value any) {
func (mp *threadLocalMap) set(index int, value entry) {
lookup := mp.table
if index < len(lookup) {
lookup[index] = value
Expand All @@ -34,7 +34,7 @@ func (mp *threadLocalMap) remove(index int) {
}
}

func (mp *threadLocalMap) expandAndSet(index int, value any) {
func (mp *threadLocalMap) expandAndSet(index int, value entry) {
oldArray := mp.table
oldCapacity := len(oldArray)
newCapacity := index
Expand All @@ -45,7 +45,7 @@ func (mp *threadLocalMap) expandAndSet(index int, value any) {
newCapacity |= newCapacity >> 16
newCapacity++

newArray := make([]any, newCapacity)
newArray := make([]entry, newCapacity)
copy(newArray, oldArray)
fill(newArray, oldCapacity, newCapacity, unset)
newArray[index] = value
Expand All @@ -65,11 +65,11 @@ func createInheritedMap() *threadLocalMap {
if lookup == nil {
return nil
}
table := make([]any, len(lookup))
table := make([]entry, len(lookup))
copy(table, lookup)
for i := 0; i < len(table); i++ {
if c, ok := table[i].(Cloneable); ok {
table[i] = c.Clone()
if c, ok := entryAssert[Cloneable](table[i]); ok {
table[i] = entry(c.Clone())
}
}
return &threadLocalMap{table: table}
Expand Down
16 changes: 16 additions & 0 deletions thread_local_map_entry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package routine

type entry any

func entryValue[T any](e entry) T {
if e == nil {
var defaultValue T
return defaultValue
}
return e.(T)
}

func entryAssert[T any](e entry) (T, bool) {
v, ok := e.(T)
return v, ok
}
Loading

0 comments on commit 37a5553

Please sign in to comment.