Skip to content

Commit

Permalink
Merge pull request #25 from pkoukk/enhance_safe-get-encoding
Browse files Browse the repository at this point in the history
enhancement: thread safe for get encoding
  • Loading branch information
pkoukk authored Jun 16, 2023
2 parents 9fda5e9 + 7f4fad9 commit 77200b6
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 53 deletions.
123 changes: 72 additions & 51 deletions encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tiktoken

import (
"errors"
"sync"
)

const ENDOFTEXT string = "<|endoftext|>"
Expand All @@ -10,52 +11,73 @@ const FIM_MIDDLE string = "<|fim_middle|>"
const FIM_SUFFIX string = "<|fim_suffix|>"
const ENDOFPROMPT string = "<|endofprompt|>"

const (
MODEL_CL100K_BASE string = "cl100k_base"
MODEL_P50K_BASE string = "p50k_base"
MODEL_P50K_EDIT string = "p50k_edit"
MODEL_R50K_BASE string = "r50k_base"
)

var MODEL_TO_ENCODING = map[string]string{
// chat
"gpt-4": "cl100k_base",
"gpt-3.5-turbo": "cl100k_base",
"gpt-4": MODEL_CL100K_BASE,
"gpt-3.5-turbo": MODEL_CL100K_BASE,
// text
"text-davinci-003": "p50k_base",
"text-davinci-002": "p50k_base",
"text-davinci-001": "r50k_base",
"text-curie-001": "r50k_base",
"text-babbage-001": "r50k_base",
"text-ada-001": "r50k_base",
"davinci": "r50k_base",
"curie": "r50k_base",
"babbage": "r50k_base",
"ada": "r50k_base",
"text-davinci-003": MODEL_P50K_BASE,
"text-davinci-002": MODEL_P50K_BASE,
"text-davinci-001": MODEL_R50K_BASE,
"text-curie-001": MODEL_R50K_BASE,
"text-babbage-001": MODEL_R50K_BASE,
"text-ada-001": MODEL_R50K_BASE,
"davinci": MODEL_R50K_BASE,
"curie": MODEL_R50K_BASE,
"babbage": MODEL_R50K_BASE,
"ada": MODEL_R50K_BASE,
// code
"code-davinci-002": "p50k_base",
"code-davinci-001": "p50k_base",
"code-cushman-002": "p50k_base",
"code-cushman-001": "p50k_base",
"davinci-codex": "p50k_base",
"cushman-codex": "p50k_base",
"code-davinci-002": MODEL_P50K_BASE,
"code-davinci-001": MODEL_P50K_BASE,
"code-cushman-002": MODEL_P50K_BASE,
"code-cushman-001": MODEL_P50K_BASE,
"davinci-codex": MODEL_P50K_BASE,
"cushman-codex": MODEL_P50K_BASE,
// edit
"text-davinci-edit-001": "p50k_edit",
"code-davinci-edit-001": "p50k_edit",
"text-davinci-edit-001": MODEL_P50K_EDIT,
"code-davinci-edit-001": MODEL_P50K_EDIT,
// embeddings
"text-embedding-ada-002": "cl100k_base",
"text-embedding-ada-002": MODEL_CL100K_BASE,
// old embeddings
"text-similarity-davinci-001": "r50k_base",
"text-similarity-curie-001": "r50k_base",
"text-similarity-babbage-001": "r50k_base",
"text-similarity-ada-001": "r50k_base",
"text-search-davinci-doc-001": "r50k_base",
"text-search-curie-doc-001": "r50k_base",
"text-search-babbage-doc-001": "r50k_base",
"text-search-ada-doc-001": "r50k_base",
"code-search-babbage-code-001": "r50k_base",
"code-search-ada-code-001": "r50k_base",
"text-similarity-davinci-001": MODEL_R50K_BASE,
"text-similarity-curie-001": MODEL_R50K_BASE,
"text-similarity-babbage-001": MODEL_R50K_BASE,
"text-similarity-ada-001": MODEL_R50K_BASE,
"text-search-davinci-doc-001": MODEL_R50K_BASE,
"text-search-curie-doc-001": MODEL_R50K_BASE,
"text-search-babbage-doc-001": MODEL_R50K_BASE,
"text-search-ada-doc-001": MODEL_R50K_BASE,
"code-search-babbage-code-001": MODEL_R50K_BASE,
"code-search-ada-code-001": MODEL_R50K_BASE,
// open source
"gpt2": "gpt2",
}

var MODEL_PREFIX_TO_ENCODING = map[string]string{
// chat
"gpt-4-": "cl100k_base", // e.g., gpt-4-0314, etc., plus gpt-4-32k
"gpt-3.5-turbo-": "cl100k_base", // e.g, gpt-3.5-turbo-0301, -0401, etc.
"gpt-4-": MODEL_CL100K_BASE, // e.g., gpt-4-0314, etc., plus gpt-4-32k
"gpt-3.5-turbo-": MODEL_CL100K_BASE, // e.g, gpt-3.5-turbo-0301, -0401, etc.
}

var encodingMap map[string]*Encoding

var onceMaps map[string]*sync.Once

func init() {
encodingMap = make(map[string]*Encoding)
onceMaps = make(map[string]*sync.Once)
for _, encodingName := range MODEL_TO_ENCODING {
if _, ok := onceMaps[encodingName]; !ok {
onceMaps[encodingName] = &sync.Once{}
}
}
}

type Encoding struct {
Expand All @@ -67,27 +89,26 @@ type Encoding struct {
}

func getEncoding(encodingName string) (*Encoding, error) {
encoding, ok := ENCODING_MAP[encodingName]
if !ok {
initEncoding, err := initEncoding(encodingName)
if err != nil {
return nil, err
}
encoding = initEncoding
ENCODING_MAP[encodingName] = encoding
if encoding, ok := encodingMap[encodingName]; ok {
return encoding, nil
}
initEncoding, err := initEncoding(encodingName)
if err != nil {
return nil, err
}
return encoding, nil
onceMaps[encodingName].Do(func() { encodingMap[encodingName] = initEncoding })
return encodingMap[encodingName], nil
}

func initEncoding(encodingName string) (*Encoding, error) {
switch encodingName {
case "cl100k_base":
case MODEL_CL100K_BASE:
return cl100k_base()
case "p50k_base":
case MODEL_P50K_BASE:
return p50k_base()
case "r50k_base":
case MODEL_R50K_BASE:
return r50k_base()
case "p50k_edit":
case MODEL_P50K_EDIT:
return p50k_edit()
default:
return nil, errors.New("Unknown encoding: " + encodingName)
Expand All @@ -107,7 +128,7 @@ func cl100k_base() (*Encoding, error) {
ENDOFPROMPT: 100276,
}
return &Encoding{
Name: "cl100k_base",
Name: MODEL_CL100K_BASE,
PatStr: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
MergeableRanks: ranks,
SpecialTokens: special_tokens,
Expand All @@ -121,7 +142,7 @@ func p50k_edit() (*Encoding, error) {
}
special_tokens := map[string]int{ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
return &Encoding{
Name: "p50k_edit",
Name: MODEL_P50K_EDIT,
PatStr: `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`,
MergeableRanks: ranks,
SpecialTokens: special_tokens,
Expand All @@ -143,7 +164,7 @@ func p50k_base() (*Encoding, error) {
// }

return &Encoding{
Name: "p50k_base",
Name: MODEL_P50K_BASE,
PatStr: `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`,
MergeableRanks: ranks,
SpecialTokens: special_tokens,
Expand All @@ -158,12 +179,12 @@ func r50k_base() (*Encoding, error) {
}
special_tokens := map[string]int{ENDOFTEXT: 50256}
return &Encoding{
Name: "r50k_base",
Name: MODEL_R50K_BASE,
MergeableRanks: ranks,
PatStr: `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`,
SpecialTokens: special_tokens,
ExplicitNVocab: 50257,
}, nil
}

var ENCODING_MAP = map[string]*Encoding{}
// var ENCODING_MAP = map[string]*Encoding{}
5 changes: 3 additions & 2 deletions tiktoken_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

func TestEncoding(t *testing.T) {
ass := assert.New(t)
enc, err := GetEncoding("cl100k_base")
enc, err := EncodingForModel("gpt-3.5-turbo-16k")
ass.Nil(err, "Encoding init should not be nil")
tokens := enc.Encode("hello world!你好,世界!", []string{"all"}, []string{"all"})
// these tokens are converted from the original python code
Expand All @@ -33,7 +33,8 @@ func TestEncoding(t *testing.T) {

func TestDecoding(t *testing.T) {
ass := assert.New(t)
enc, err := GetEncoding("cl100k_base")
// enc, err := GetEncoding("cl100k_base")
enc, err := GetEncoding(MODEL_CL100K_BASE)
ass.Nil(err, "Encoding init should not be nil")
sourceTokens := []int{15339, 1917, 0, 57668, 53901, 3922, 3574, 244, 98220, 6447}
txt := enc.Decode(sourceTokens)
Expand Down

0 comments on commit 77200b6

Please sign in to comment.