Skip to content

Commit

Permalink
Issue 430 (#434)
Browse files Browse the repository at this point in the history
  • Loading branch information
tcdsv authored Nov 18, 2023
1 parent 792c95c commit 66b7f71
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 5 deletions.
25 changes: 24 additions & 1 deletion flatten/merge_allof.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,22 @@ func Merge(schema openapi3.SchemaRef) (*openapi3.Schema, error) {
if err != nil {
return nil, err
}
pruneFields(schema)
}

return result.Value, nil
}

// remove fields while maintaining an equivalent schema.
func pruneFields(schema *openapi3.SchemaRef) {
if len(schema.Value.OneOf) == 1 && schema.Value.OneOf[0].Value == schema.Value {
schema.Value.OneOf = nil
}
if len(schema.Value.AnyOf) == 1 && schema.Value.AnyOf[0].Value == schema.Value {
schema.Value.AnyOf = nil
}
}

func mergeCircularAllOf(state *state, baseSchemaRef *openapi3.SchemaRef) error {
schemaRefs := openapi3.SchemaRefs{baseSchemaRef}
schemaRefs = append(schemaRefs, baseSchemaRef.Value.AllOf...)
Expand Down Expand Up @@ -630,8 +641,20 @@ func resolveEnum(values [][]interface{}) ([]interface{}, error) {
}

func resolvePattern(values []string) string {
patterns := []string{}
for _, v := range values {
if len(v) > 0 {
patterns = append(patterns, v)
}
}
if len(patterns) == 0 {
return ""
}
if len(patterns) == 1 {
return patterns[0]
}
var pattern strings.Builder
for _, p := range values {
for _, p := range patterns {
if len(p) > 0 {
if !isPatternResolved(p) {
pattern.WriteString(fmt.Sprintf("(?=%s)", p))
Expand Down
64 changes: 60 additions & 4 deletions flatten/merge_allof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1715,8 +1715,30 @@ func TestMerge_AdditionalProperties_True(t *testing.T) {
}

func TestMergeAllOf_Pattern(t *testing.T) {

merged, err := flatten.Merge(
openapi3.SchemaRef{
Value: &openapi3.Schema{
Pattern: "abc",
}})
require.NoError(t, err)
require.Equal(t, "abc", merged.Pattern)

merged, err = flatten.Merge(
openapi3.SchemaRef{
Value: &openapi3.Schema{
AllOf: openapi3.SchemaRefs{
&openapi3.SchemaRef{
Value: &openapi3.Schema{
Type: "object",
Pattern: "abc",
},
},
},
}})
require.NoError(t, err)
require.Equal(t, "abc", merged.Pattern)

merged, err = flatten.Merge(
openapi3.SchemaRef{
Value: &openapi3.Schema{
AllOf: openapi3.SchemaRefs{
Expand Down Expand Up @@ -1778,14 +1800,48 @@ func TestMerge_CircularAllOf(t *testing.T) {
merged, err := flatten.Merge(*doc.Components.Schemas["AWSEnvironmentSettings"])
require.NoError(t, err)
require.Empty(t, merged.AllOf)

require.Equal(t, "#/components/schemas/AWSEnvironmentSettings", merged.OneOf[0].Ref)
require.Equal(t, &merged, &merged.OneOf[0].Value)
require.Empty(t, merged.OneOf)

require.Equal(t, "string", merged.Properties["serviceEndpoints"].Value.Type)
require.Equal(t, "string", merged.Properties["region"].Value.Type)
}

// A single OneOf field is pruned if it references it's parent schema
func TestMerge_OneOfIsPruned(t *testing.T) {
doc := loadSpec(t, "testdata/circular2.yaml")
merged, err := flatten.Merge(*doc.Components.Schemas["OneOf_Is_Pruned_B"])
require.NoError(t, err)
require.Empty(t, merged.AllOf)
require.Empty(t, merged.OneOf)
}

// A single OneOf field is not pruned if it does not reference it's parent schema
func TestMerge_OneOfIsNotPruned(t *testing.T) {
doc := loadSpec(t, "testdata/circular2.yaml")
merged, err := flatten.Merge(*doc.Components.Schemas["OneOf_Is_Not_Pruned_B"])
require.NoError(t, err)
require.Empty(t, merged.AllOf)
require.NotEmpty(t, merged.OneOf)
}

// A single AnyOf field is pruned if it references it's parent schema
func TestMerge_AnyOfIsPruned(t *testing.T) {
doc := loadSpec(t, "testdata/circular2.yaml")
merged, err := flatten.Merge(*doc.Components.Schemas["AnyOf_Is_Pruned_B"])
require.NoError(t, err)
require.Empty(t, merged.AllOf)
require.Empty(t, merged.AnyOf)
}

// A single AnyOf field is not pruned if it does not reference it's parent schema
func TestMerge_AnyOfIsNotPruned(t *testing.T) {
doc := loadSpec(t, "testdata/circular2.yaml")
merged, err := flatten.Merge(*doc.Components.Schemas["AnyOf_Is_Not_Pruned_B"])
require.NoError(t, err)
require.Empty(t, merged.AllOf)
require.NotEmpty(t, merged.AnyOf)
}

func loadSpec(t *testing.T, path string) *openapi3.T {
ctx := context.Background()
sl := openapi3.NewLoader()
Expand Down
58 changes: 58 additions & 0 deletions flatten/testdata/circular2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
openapi: 3.0.0
info:
title: Circular Reference Example
version: 1.0.0
paths:
/sample:
put:
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/OneOf_Is_Pruned_A'
responses:
'200':
description: Ok

components:
schemas:
OneOf_Is_Pruned_A:
type: object
oneOf:
- $ref: '#/components/schemas/OneOf_Is_Pruned_B'

OneOf_Is_Pruned_B:
type: object
allOf:
- $ref: '#/components/schemas/OneOf_Is_Pruned_A'

OneOf_Is_Not_Pruned_A:
type: object
oneOf:
- type: object

OneOf_Is_Not_Pruned_B:
type: object
allOf:
- $ref: '#/components/schemas/OneOf_Is_Not_Pruned_A'

AnyOf_Is_Pruned_A:
type: object
anyOf:
- $ref: '#/components/schemas/AnyOf_Is_Pruned_B'

AnyOf_Is_Pruned_B:
type: object
allOf:
- $ref: '#/components/schemas/AnyOf_Is_Pruned_A'

AnyOf_Is_Not_Pruned_A:
type: object
anyOf:
- type: object

AnyOf_Is_Not_Pruned_B:
type: object
allOf:
- $ref: '#/components/schemas/AnyOf_Is_Not_Pruned_A'

0 comments on commit 66b7f71

Please sign in to comment.