Skip to content

Commit

Permalink
materialize-postgres: create load table columns based on existing table key columns
Browse files Browse the repository at this point in the history

In scenarios where there may be more than one allowed pre-existing column type,
the load table must be created with a column type corresponding to the existing
table column to ensure that join query comparisons work correctly.

This required threading a hydrated `InfoSchema` through to the `NewTransactor`
constructor, which is now a generally available capability. This
simplifies `materialize-redshift` and `materialize-mysql`, which were already
building an `InfoSchema` in their own bespoke ways. Eventually most other
materializations will probably need to do something similar to what
`materialize-postgres` does with its load table columns.

Currently the only other materialization that matches its load table column
types to the existing table column types is `materialize-bigquery`, and it may
be useful to refactor this handling in terms of the `InfoSchema` at some point
as well.
  • Loading branch information
williamhbaker committed Sep 30, 2024
1 parent 2292069 commit c0c172f
Show file tree
Hide file tree
Showing 13 changed files with 62 additions and 63 deletions.
1 change: 1 addition & 0 deletions materialize-bigquery/transactor.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ func newTransactor(
fence sql.Fence,
bindings []sql.Table,
open pm.Request_Open,
is *boilerplate.InfoSchema,
) (_ m.Transactor, err error) {
cfg := ep.Config.(*config)

Expand Down
1 change: 1 addition & 0 deletions materialize-databricks/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ func newTransactor(
fence sql.Fence,
bindings []sql.Table,
open pm.Request_Open,
is *boilerplate.InfoSchema,
) (_ m.Transactor, err error) {
var cfg = ep.Config.(*config)

Expand Down
2 changes: 2 additions & 0 deletions materialize-motherduck/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
m "github.com/estuary/connectors/go/protocols/materialize"
boilerplate "github.com/estuary/connectors/materialize-boilerplate"
sql "github.com/estuary/connectors/materialize-sql"
pf "github.com/estuary/flow/go/protocols/flow"
pm "github.com/estuary/flow/go/protocols/materialize"
Expand Down Expand Up @@ -197,6 +198,7 @@ func newTransactor(
fence sql.Fence,
bindings []sql.Table,
open pm.Request_Open,
is *boilerplate.InfoSchema,
) (_ m.Transactor, err error) {
cfg := ep.Config.(*config)

Expand Down
18 changes: 2 additions & 16 deletions materialize-mysql/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,13 +379,14 @@ func (t *transactor) Acknowledge(ctx context.Context) (*pf.ConnectorState, error
func prepareNewTransactor(
dialect sql.Dialect,
templates templates,
) func(context.Context, *sql.Endpoint, sql.Fence, []sql.Table, pm.Request_Open) (m.Transactor, error) {
) func(context.Context, *sql.Endpoint, sql.Fence, []sql.Table, pm.Request_Open, *boilerplate.InfoSchema) (m.Transactor, error) {
return func(
ctx context.Context,
ep *sql.Endpoint,
fence sql.Fence,
bindings []sql.Table,
open pm.Request_Open,
is *boilerplate.InfoSchema,
) (_ m.Transactor, err error) {
var cfg = ep.Config.(*config)
var d = &transactor{dialect: dialect, templates: templates, cfg: cfg}
Expand All @@ -403,21 +404,6 @@ func prepareNewTransactor(
return nil, fmt.Errorf("store db.Conn: %w", err)
}

db, err := stdsql.Open("mysql", cfg.ToURI())
if err != nil {
return nil, fmt.Errorf("newTransactor sql.Open: %w", err)
}
defer db.Close()

resourcePaths := make([][]string, 0, len(open.Materialization.Bindings))
for _, b := range open.Materialization.Bindings {
resourcePaths = append(resourcePaths, b.ResourcePath)
}
is, err := sql.StdFetchInfoSchema(ctx, db, ep.Dialect, "def", resourcePaths)
if err != nil {
return nil, err
}

for _, binding := range bindings {
if err = d.addBinding(ctx, binding, is); err != nil {
return nil, fmt.Errorf("addBinding of %s: %w", binding.Path, err)
Expand Down
38 changes: 26 additions & 12 deletions materialize-postgres/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ func newTransactor(
fence sql.Fence,
bindings []sql.Table,
open pm.Request_Open,
is *boilerplate.InfoSchema,
) (_ m.Transactor, err error) {
var cfg = ep.Config.(*config)

Expand All @@ -283,7 +284,7 @@ func newTransactor(
}

for _, binding := range bindings {
if err = d.addBinding(ctx, binding); err != nil {
if err = d.addBinding(ctx, binding, is); err != nil {
return nil, fmt.Errorf("addBinding of %s: %w", binding.Path, err)
}
}
Expand All @@ -299,23 +300,21 @@ func newTransactor(
}

type binding struct {
target sql.Table
createLoadTableSQL string
loadInsertSQL string
storeUpdateSQL string
storeInsertSQL string
deleteQuerySQL string
loadQuerySQL string
target sql.Table
loadInsertSQL string
storeUpdateSQL string
storeInsertSQL string
deleteQuerySQL string
loadQuerySQL string
}

func (t *transactor) addBinding(ctx context.Context, target sql.Table) error {
func (t *transactor) addBinding(ctx context.Context, target sql.Table, is *boilerplate.InfoSchema) error {
var b = &binding{target: target}

for _, m := range []struct {
sql *string
tpl *template.Template
}{
{&b.createLoadTableSQL, tplCreateLoadTable},
{&b.loadInsertSQL, tplLoadInsert},
{&b.storeInsertSQL, tplStoreInsert},
{&b.storeUpdateSQL, tplStoreUpdate},
Expand All @@ -331,8 +330,23 @@ func (t *transactor) addBinding(ctx context.Context, target sql.Table) error {
t.bindings = append(t.bindings, b)

// Create a binding-scoped temporary table for staged keys to load.
if _, err := t.load.conn.Exec(ctx, b.createLoadTableSQL); err != nil {
return fmt.Errorf("Exec(%s): %w", b.createLoadTableSQL, err)
input := loadTableColumns{Binding: b.target.Binding}
for _, k := range b.target.Keys {
existing, err := is.GetField(b.target.Path, k.Field)
if err != nil {
return fmt.Errorf("getting existing key field %s for binding %s: %w", k.Field, b.target.Path, err)
}
input.Keys = append(input.Keys, loadTableKey{
Identifier: k.Identifier,
DDL: existing.Type + " NOT NULL", // nullable key fields are not allowed
})
}

var w strings.Builder
if err := tplCreateLoadTable.Execute(&w, &input); err != nil {
return fmt.Errorf("executing createLoadTable template: %w", err)
} else if _, err := t.load.conn.Exec(ctx, w.String()); err != nil {
return fmt.Errorf("Exec(%s): %w", w.String(), err)
}

return nil
Expand Down
10 changes: 10 additions & 0 deletions materialize-postgres/sqlgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ var pgDialect = func() sql.Dialect {
}
}()

// loadTableKey describes a single key column of a binding's temporary load
// table.
type loadTableKey struct {
// Identifier is the column identifier, taken from the target table's key.
Identifier string
// DDL is the column type definition, built from the existing table column's
// type with " NOT NULL" appended.
DDL string
}

// loadTableColumns is the input used to execute the tplCreateLoadTable
// template when creating a binding's temporary load table.
type loadTableColumns struct {
// Binding is the binding number of the target table.
Binding int
// Keys holds one entry per key of the target table.
Keys []loadTableKey
}

var (
tplAll = sql.MustParseTemplate(pgDialect, "root", `
{{ define "temp_name" -}}
Expand Down
34 changes: 2 additions & 32 deletions materialize-redshift/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"fmt"
"net"
"net/url"
"slices"
"strings"
"text/template"

Expand Down Expand Up @@ -286,13 +285,14 @@ type transactor struct {
func prepareNewTransactor(
templates templates,
caseSensitiveIdentifierEnabled bool,
) func(context.Context, *sql.Endpoint, sql.Fence, []sql.Table, pm.Request_Open) (m.Transactor, error) {
) func(context.Context, *sql.Endpoint, sql.Fence, []sql.Table, pm.Request_Open, *boilerplate.InfoSchema) (m.Transactor, error) {
return func(
ctx context.Context,
ep *sql.Endpoint,
fence sql.Fence,
bindings []sql.Table,
open pm.Request_Open,
is *boilerplate.InfoSchema,
) (_ m.Transactor, err error) {
var cfg = ep.Config.(*config)

Expand All @@ -314,36 +314,6 @@ func prepareNewTransactor(
return nil, err
}

db, err := stdsql.Open("pgx", d.cfg.toURI())
if err != nil {
return nil, err
}
defer db.Close()

schemas := []string{}
for _, b := range bindings {
if !slices.Contains(schemas, b.InfoLocation.TableSchema) {
schemas = append(schemas, b.InfoLocation.TableSchema)
}
}

catalog := cfg.Database
if catalog == "" {
// An endpoint-level database configuration is not required, so query for the active
// database if that's the case.
if err := db.QueryRowContext(ctx, "select current_database();").Scan(&catalog); err != nil {
return nil, fmt.Errorf("querying for connected database: %w", err)
}
}
resourcePaths := make([][]string, 0, len(open.Materialization.Bindings))
for _, b := range open.Materialization.Bindings {
resourcePaths = append(resourcePaths, b.ResourcePath)
}
is, err := sql.StdFetchInfoSchema(ctx, db, ep.Dialect, catalog, resourcePaths)
if err != nil {
return nil, err
}

for idx, target := range bindings {
if err = d.addBinding(
idx,
Expand Down
1 change: 1 addition & 0 deletions materialize-snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ func newTransactor(
fence sql.Fence,
bindings []sql.Table,
open pm.Request_Open,
is *boilerplate.InfoSchema,
) (_ m.Transactor, err error) {
var cfg = ep.Config.(*config)

Expand Down
13 changes: 12 additions & 1 deletion materialize-sql/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,10 @@ func (d *Driver) NewTransactor(ctx context.Context, open pm.Request_Open) (m.Tra
}
defer client.Close()

var resourcePaths [][]string
if endpoint.MetaSpecs != nil {
resourcePaths = append(resourcePaths, endpoint.MetaSpecs.Path)

if _, loadedVersion, err = loadSpec(ctx, client, endpoint, open.Materialization.Name); err != nil {
return nil, nil, fmt.Errorf("loading prior applied materialization spec: %w", err)
} else if loadedVersion == "" {
Expand All @@ -261,6 +264,7 @@ func (d *Driver) NewTransactor(ctx context.Context, open pm.Request_Open) (m.Tra
var tables []Table
for index, spec := range open.Materialization.Bindings {
var resource = endpoint.NewResource(endpoint)
resourcePaths = append(resourcePaths, resource.Path())

if err := pf.UnmarshalStrict(spec.ResourceConfigJson, resource); err != nil {
return nil, nil, fmt.Errorf("resource binding for collection %q: %w", spec.Collection.Name, err)
Expand All @@ -285,6 +289,8 @@ func (d *Driver) NewTransactor(ctx context.Context, open pm.Request_Open) (m.Tra
}

if endpoint.MetaCheckpoints != nil {
resourcePaths = append(resourcePaths, endpoint.MetaCheckpoints.Path)

// We must install a fence to prevent other (zombie) instances of this
// materialization from committing further transactions.
var metaCheckpoints, err = ResolveTable(*endpoint.MetaCheckpoints, endpoint.Dialect)
Expand All @@ -303,7 +309,12 @@ func (d *Driver) NewTransactor(ctx context.Context, open pm.Request_Open) (m.Tra
}
}

transactor, err := endpoint.NewTransactor(ctx, endpoint, fence, tables, open)
is, err := client.InfoSchema(ctx, resourcePaths)
if err != nil {
return nil, nil, fmt.Errorf("getting info schema: %w", err)
}

transactor, err := endpoint.NewTransactor(ctx, endpoint, fence, tables, open, is)
if err != nil {
return nil, nil, fmt.Errorf("building transactor: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion materialize-sql/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ type Endpoint struct {
// which will be parsed into and validated from a resource configuration.
NewResource func(*Endpoint) Resource
// NewTransactor returns a Transactor ready for pm.RunTransactions.
NewTransactor func(ctx context.Context, _ *Endpoint, _ Fence, bindings []Table, open pm.Request_Open) (m.Transactor, error)
NewTransactor func(ctx context.Context, _ *Endpoint, _ Fence, bindings []Table, open pm.Request_Open, is *boilerplate.InfoSchema) (m.Transactor, error)
// Tenant owning this task, as determined from the task name.
Tenant string
// ConcurrentApply of Apply actions, for systems that may benefit from a scatter/gather strategy
Expand Down
1 change: 1 addition & 0 deletions materialize-sqlite/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ func newTransactor(
fence sql.Fence,
bindings []sql.Table,
open pm.Request_Open,
is *boilerplate.InfoSchema,
) (_ m.Transactor, err error) {
var d = &transactor{
dialect: &sqliteDialect,
Expand Down
3 changes: 2 additions & 1 deletion materialize-sqlserver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,14 @@ type transactor struct {

func prepareNewTransactor(
templates templates,
) func(context.Context, *sql.Endpoint, sql.Fence, []sql.Table, pm.Request_Open) (m.Transactor, error) {
) func(context.Context, *sql.Endpoint, sql.Fence, []sql.Table, pm.Request_Open, *boilerplate.InfoSchema) (m.Transactor, error) {
return func(
ctx context.Context,
ep *sql.Endpoint,
fence sql.Fence,
bindings []sql.Table,
open pm.Request_Open,
is *boilerplate.InfoSchema,
) (_ m.Transactor, err error) {
var cfg = ep.Config.(*config)
var d = &transactor{templates: templates, cfg: cfg}
Expand Down
1 change: 1 addition & 0 deletions materialize-starburst/starburst.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ func newTransactor(
_ sql.Fence,
tables []sql.Table,
open pm.Request_Open,
is *boilerplate.InfoSchema,
) (_ m.Transactor, err error) {
var cfg = ep.Config.(*config)
var templates = renderTemplates(starburstTrinoDialect)
Expand Down

0 comments on commit c0c172f

Please sign in to comment.