diff --git a/backend/internal/license/license_detector.go b/backend/internal/license/license_detector.go index 9b26bdef..de9117c7 100644 --- a/backend/internal/license/license_detector.go +++ b/backend/internal/license/license_detector.go @@ -119,6 +119,16 @@ func WithConfidenceThreshold(threshold float32) Opt { } } +func WithConfidenceOverrideThreshold(threshold float32) Opt { + return func(config *Config) error { + if threshold < 0 || threshold > 1 { + return fmt.Errorf("invalid threshold: %f", threshold) + } + config.ConfidenceOverrideThreshold = threshold + return nil + } +} + func WithCompatibleLicenses(licenses ...string) Opt { return func(config *Config) error { config.CompatibleLicenses = append(config.CompatibleLicenses, licenses...) @@ -129,12 +139,18 @@ func WithCompatibleLicenses(licenses ...string) Opt { type Config struct { CompatibleLicenses []string ConfidenceThreshold float32 + // ConfidenceOverrideThreshold is the limit at which a detected license overrides all other detected licenses. + // Defaults to 98%. + ConfidenceOverrideThreshold float32 } func (c *Config) ApplyDefaults() error { if c.ConfidenceThreshold == 0.0 { c.ConfidenceThreshold = 0.9 } + if c.ConfidenceOverrideThreshold == 0 { + c.ConfidenceOverrideThreshold = 0.98 + } return nil } @@ -142,6 +158,9 @@ func (c *Config) Validate() error { if len(c.CompatibleLicenses) == 0 { return fmt.Errorf("no licenses configured") } + if c.ConfidenceOverrideThreshold < c.ConfidenceThreshold { + return fmt.Errorf("the confidence override threshold (%f) is lower than the confidence threshold (%f)", c.ConfidenceOverrideThreshold, c.ConfidenceThreshold) + } return nil } @@ -230,6 +249,11 @@ func (d detector) Detect(_ context.Context, repository fs.ReadDirFS, detectOptio if err := detectCfg.LinkFetcher(&l); err != nil { return nil, err } + if match.Confidence >= d.config.ConfidenceOverrideThreshold { + return []License{ + l, + }, nil + } result = append(result, l) } }