-
-
Notifications
You must be signed in to change notification settings - Fork 179
/
db_test.go
168 lines (150 loc) · 3.97 KB
/
db_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
package dynamo
import (
"context"
"errors"
"fmt"
"log"
"os"
"strconv"
"strings"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/retry"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/smithy-go"
)
// Shared fixtures for the integration tests in this package.
var (
// testDB is assigned by TestMain only when a region is known, either from
// DYNAMO_TEST_REGION or defaulted to "local" because DYNAMO_TEST_ENDPOINT is
// set. It stays nil otherwise.
testDB *DB
// testTableWidgets is the first test table's name; TestMain replaces it with
// DYNAMO_TEST_TABLE when that variable is set.
testTableWidgets = "TestDB"
// testTableSprockets is the second test table's name; TestMain replaces it
// with DYNAMO_TEST_TABLE2 when set, or derives it from testTableWidgets.
testTableSprockets = "TestDB-Sprockets"
)
// dummyCreds is a static credentials provider built from placeholder values.
// NOTE(review): it is not referenced in the visible part of this file;
// presumably other tests in the package use it.
var dummyCreds = credentials.NewStaticCredentialsProvider("dummy", "dummy", "")
// offlineSkipMsg is the message passed to t.Skip when testDB is nil.
const offlineSkipMsg = "DYNAMO_TEST_REGION not set"
// widget is the data structure used for integration tests.
// TestMain passes a zero widget to CreateTable when it creates the test tables.
type widget struct {
// UserID is tagged as the hash key.
UserID int `dynamo:",hash"`
// Time is tagged as the range key, and as the range key of Msg-Time-index.
Time time.Time `dynamo:",range" index:"Msg-Time-index,range"`
// Msg is tagged as the hash key of Msg-Time-index.
Msg string `index:"Msg-Time-index,hash"`
Count int
Meta map[string]string
// StrPtr carries the allowempty option in its dynamo tag.
StrPtr *string `dynamo:",allowempty"`
}
// TestMain prepares the shared integration-test fixtures, runs the test
// suite, deletes any tables it created, and exits with the suite's status.
//
// Environment variables:
//   - DYNAMO_TEST_ENDPOINT: custom endpoint URL. When it is set without a
//     region, the region defaults to "local".
//   - DYNAMO_TEST_REGION: region used to load the AWS config. testDB is only
//     initialised when a region is known; it stays nil otherwise.
//   - DYNAMO_TEST_TABLE, DYNAMO_TEST_TABLE2: table names. Every "%" in them is
//     replaced with the current Unix time in milliseconds. When
//     DYNAMO_TEST_TABLE2 is unset the second name is the first plus
//     "-Sprockets".
//   - DYNAMO_TEST_CREATE_TABLE: "1"/"true"/"yes" or "0"/"false"/"no". Any
//     other value (including unset) creates missing tables only when a custom
//     endpoint is configured.
func TestMain(m *testing.M) {
	endpoint := os.Getenv("DYNAMO_TEST_ENDPOINT")
	region := os.Getenv("DYNAMO_TEST_REGION")
	if endpoint != "" && region == "" {
		region = "local"
	}

	if region != "" {
		// resolv is left nil when no custom endpoint was requested.
		var resolv aws.EndpointResolverWithOptions
		if endpoint != "" {
			resolv = aws.EndpointResolverWithOptionsFunc(
				func(_, _ string, _ ...interface{}) (aws.Endpoint, error) {
					return aws.Endpoint{URL: endpoint}, nil
				},
			)
		}
		// The standard retryer is configured with RetryTxConflicts.
		// NOTE(review): presumably this makes TransactionCanceledException
		// conflicts retryable; confirm against RetryTxConflicts' definition.
		cfg, err := config.LoadDefaultConfig(
			context.Background(),
			config.WithRegion(region),
			config.WithEndpointResolverWithOptions(resolv),
			config.WithRetryer(func() aws.Retryer {
				return retry.NewStandard(RetryTxConflicts)
			}),
		)
		if err != nil {
			log.Fatal(err)
		}
		testDB = New(cfg)
	}

	// Expand "%" so each run can get unique table names,
	// e.g. Test-% --> Test-1707708680863.
	timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10)
	if table := os.Getenv("DYNAMO_TEST_TABLE"); table != "" {
		testTableWidgets = strings.ReplaceAll(table, "%", timestamp)
	}
	if table := os.Getenv("DYNAMO_TEST_TABLE2"); table != "" {
		testTableSprockets = strings.ReplaceAll(table, "%", timestamp)
	} else {
		testTableSprockets = testTableWidgets + "-Sprockets"
	}
	if testTableWidgets == testTableSprockets {
		panic(fmt.Sprintf("DYNAMO_TEST_TABLE must not equal DYNAMO_TEST_TABLE2. got DYNAMO_TEST_TABLE=%q and DYNAMO_TEST_TABLE2=%q",
			testTableWidgets, testTableSprockets))
	}

	// Default: only create tables against a custom endpoint.
	shouldCreate := endpoint != ""
	switch os.Getenv("DYNAMO_TEST_CREATE_TABLE") {
	case "1", "true", "yes":
		shouldCreate = true
	case "0", "false", "no":
		shouldCreate = false
	}

	ctx := context.Background()
	var created []Table
	// dropCreated deletes every table this run created. Deletion failures are
	// logged rather than fatal so the remaining tables are still attempted.
	dropCreated := func() {
		for _, table := range created {
			log.Println("Deleting test table:", table.Name())
			if err := table.DeleteTable().Run(ctx); err != nil {
				log.Println("Error deleting test table:", table.Name(), err)
			}
		}
	}

	if testDB != nil {
		for _, name := range []string{testTableWidgets, testTableSprockets} {
			table := testDB.Table(name)
			log.Println("Checking test table:", name)
			_, err := table.Describe().Run(ctx)
			switch {
			case isTableNotExistsErr(err) && shouldCreate:
				log.Println("Creating test table:", name)
				if err := testDB.CreateTable(name, widget{}).Run(ctx); err != nil {
					// Don't leak tables created earlier in this loop.
					dropCreated()
					panic(err)
				}
				created = append(created, table)
			case err != nil:
				// Don't leak tables created earlier in this loop.
				dropCreated()
				panic(err)
			}
		}
	}

	code := m.Run()
	dropCreated()
	// Exit explicitly after cleanup: os.Exit skips deferred calls, and a
	// deferred os.Exit would hide a panic raised during cleanup.
	os.Exit(code)
}
// isTableNotExistsErr reports whether err wraps a smithy.APIError whose error
// code is "ResourceNotFoundException". Any other error, including nil, yields
// false.
func isTableNotExistsErr(err error) bool {
	var apiErr smithy.APIError
	if !errors.As(err, &apiErr) {
		return false
	}
	return apiErr.ErrorCode() == "ResourceNotFoundException"
}
// TestListTables verifies that the widgets test table appears in the result
// of ListTables. It is skipped when testDB has not been configured.
func TestListTables(t *testing.T) {
	if testDB == nil {
		t.Skip(offlineSkipMsg)
	}

	names, err := testDB.ListTables().All(context.TODO())
	if err != nil {
		t.Error(err)
		return
	}

	// Return as soon as the table is seen; falling out of the loop means it
	// was not listed.
	for _, name := range names {
		if name == testTableWidgets {
			return
		}
	}
	t.Error("couldn't find testTable", testTableWidgets, "in:", names)
}