Skip to content

Commit

Permalink
Allow users to pass node styler rules to the visualizer component
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 665412331
  • Loading branch information
Google AI Edge authored and copybara-github committed Aug 20, 2024
1 parent 0183cf6 commit b00be73
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 60 deletions.
38 changes: 32 additions & 6 deletions src/ui/src/components/visualizer/common/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -536,12 +536,38 @@ export declare interface ShowOnEdgeItemData {
selected: boolean;
}

/** The ids of the node style. */
export enum NodeStyleId {
NODE_BG_COLOR = 'node_bg_color',
NODE_TEXT_COLOR = 'node_text_color',
NODE_BORDER_COLOR = 'node_border_color',
}

/** A rule for node styler. All fields should be serializable. */
export declare interface NodeStylerRule {
/**
* Quries are connected with AND.
*/
queries: NodeQuery[];
nodeType: SearchNodeType;
// Indexed by style ids.
styles: Record<string, SerializedStyle>;

/**
* The type of node to match.
*
* @deprecated The new version the of rule stores node type as a query in
* `queries` above.
*/
nodeType?: SearchNodeType;

/**
* Styles applied to the matched nodes.
*
* Indexed by style ids.
*/
styles: Partial<Record<NodeStyleId, SerializedStyle | string>>;

/**
* Should set this to V2.
*/
version?: NodeStylerRuleVersion;
}

Expand All @@ -557,8 +583,8 @@ declare interface NodeQueryBase {
/** A rule width processed node styler rules. */
export interface ProcessedNodeStylerRule {
queries: ProcessedNodeQuery[];
nodeType: SearchNodeType;
styles: Record<string, SerializedStyle>;
nodeType?: SearchNodeType;
styles: Record<string, SerializedStyle | string>;
}

declare interface ProcessedNodeQueryBase {
Expand Down Expand Up @@ -621,6 +647,6 @@ export enum SearchNodeType {

/** Serialized style. */
export declare interface SerializedStyle {
id: string;
id: NodeStyleId;
value: string;
}
17 changes: 17 additions & 0 deletions src/ui/src/components/visualizer/common/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import {
NodeDataProviderRunData,
NodeQuery,
NodeQueryType,
NodeStyleId,
NodeStylerRule,
Point,
ProcessedNodeQuery,
Expand Down Expand Up @@ -933,3 +934,19 @@ export function getHighQualityPixelRatio(): number {
? 1.5 /* This makes rendering result sharper on non-retina displays */
: window.devicePixelRatio;
}

/** Get the value for the given style. */
export function getNodeStyleValue(
rule: ProcessedNodeStylerRule | NodeStylerRule,
styleId: NodeStyleId,
): string {
const curStyle = rule.styles[styleId];
if (curStyle) {
if (typeof curStyle === 'string') {
return curStyle;
} else {
return curStyle.value;
}
}
return '';
}
7 changes: 6 additions & 1 deletion src/ui/src/components/visualizer/common/visualizer_config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
* ==============================================================================
*/

import {RendererType} from './types';
import {NodeStylerRule, RendererType} from './types';

/** Configs for the visualizer. */
export declare interface VisualizerConfig {
Expand Down Expand Up @@ -50,8 +50,13 @@ export declare interface VisualizerConfig {
/** Whether to keep layers with a single child. */
keepLayersWithASingleChild?: boolean;

/** The default node styler rules. */
nodeStylerRules?: NodeStylerRule[];

/**
* Default graph renderer.
*
* @deprecated This field is no longer used.
*/
defaultRenderer?: RendererType;

Expand Down
4 changes: 4 additions & 0 deletions src/ui/src/components/visualizer/model_graph_visualizer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ export class ModelGraphVisualizer implements OnInit, OnDestroy, OnChanges {
private readonly threejsService: ThreejsService,
private readonly uiStateService: UiStateService,
private readonly nodeDataProviderExtensionService: NodeDataProviderExtensionService,
private readonly nodeStylerService: NodeStylerService,
) {
effect(() => {
const curUiState = this.uiStateService.curUiState();
Expand Down Expand Up @@ -184,6 +185,9 @@ export class ModelGraphVisualizer implements OnInit, OnDestroy, OnChanges {
this.appService.config.set(this.config || {});
this.appService.addGraphCollections(this.graphCollections);
this.appService.curInitialUiState.set(this.initialUiState);
if (this.config?.nodeStylerRules) {
this.nodeStylerService.rules.set(this.config.nodeStylerRules);
}

// No initial ui state. Use the graph with the most node counts as the
// default selected graph.
Expand Down
10 changes: 2 additions & 8 deletions src/ui/src/components/visualizer/node_styler_dialog.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import {
SearchMatchType,
SearchNodeType,
} from './common/types';
import {getNodeStyleValue} from './common/utils';
import {ComplexQueries} from './complex_queries';
import {NodeListViewer} from './node_list_viewer';
import {
Expand Down Expand Up @@ -205,13 +206,6 @@ export class NodeStylerDialog {
this.nodeStylerService.updateStyleValue(index, style, `${n}`);
}

handleNodeTypeChanged(ruleIndex: number, nodeType: string) {
this.nodeStylerService.updateNodeType(
ruleIndex,
nodeType as SearchNodeType,
);
}

handleMoveUpRule(index: number) {
this.nodeStylerService.moveUpRule(index);
}
Expand All @@ -233,7 +227,7 @@ export class NodeStylerDialog {
}

getSerializedStyleValue(rule: NodeStylerRule, style: Style): string {
return rule.styles[style.id]?.value || '';
return getNodeStyleValue(rule, style.id);
}

getMatchedNodes(ruleIndex: number, paneIndex: number): ModelNode[] {
Expand Down
36 changes: 14 additions & 22 deletions src/ui/src/components/visualizer/node_styler_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import {ModelNode} from './common/model_graph';
import {
NodeQuery,
NodeQueryType,
NodeStyleId,
NodeStylerRule,
NodeStylerRuleVersion,
NodeTypeQuery,
Expand All @@ -41,7 +42,7 @@ import {LocalStorageService} from './local_storage_service';
export interface Style {
type: StyleType;
label: string;
id: string;
id: NodeStyleId;
defaultValue: string;
}

Expand All @@ -51,35 +52,27 @@ export enum StyleType {
NUMBER = 'NUMBER',
}

/** Style ids. */
export enum StyleId {
NODE_BG_COLOR = 'node_bg_color',
NODE_TEXT_COLOR = 'node_text_color',
NODE_BORDER_COLOR = 'node_border_color',
NODE_SCALE = 'node_scale',
}

/** Style for node background color. */
export const NODE_BG_COLOR_STYLE: Style = {
type: StyleType.COLOR,
label: 'Bg color',
id: StyleId.NODE_BG_COLOR,
id: NodeStyleId.NODE_BG_COLOR,
defaultValue: '#ffffff',
};

/** Style for node border color. */
export const NODE_BORDER_COLOR_STYLE: Style = {
type: StyleType.COLOR,
label: 'Border color',
id: StyleId.NODE_BORDER_COLOR,
id: NodeStyleId.NODE_BORDER_COLOR,
defaultValue: '#777777',
};

/** Style for node text color. */
export const NODE_TEXT_COLOR_STYLE: Style = {
type: StyleType.COLOR,
label: 'Text color',
id: StyleId.NODE_TEXT_COLOR,
id: NodeStyleId.NODE_TEXT_COLOR,
defaultValue: '#041e49',
};

Expand Down Expand Up @@ -251,18 +244,17 @@ export class NodeStylerService {
});
}

updateNodeType(ruleIndex: number, nodeType: SearchNodeType) {
this.rules.update((rules) => {
const rule = rules[ruleIndex];
rule.nodeType = nodeType;
return [...rules];
});
}

updateStyleValue(ruleIndex: number, style: Style, value: string) {
this.rules.update((rules) => {
const rule = rules[ruleIndex];
rule.styles[style.id].value = value;
const curStyle = rule.styles[style.id];
if (curStyle) {
if (typeof curStyle === 'string') {
rule.styles[style.id] = value;
} else {
curStyle.value = value;
}
}
return [...rules];
});
}
Expand All @@ -279,7 +271,7 @@ export class NodeStylerService {
return rules.map((rule) => {
// For older version of the rule, convert the node type to
// a query.
if (rule.version == null) {
if (rule.version == null && rule.nodeType) {
const nodeTypeQuery: NodeTypeQuery = {
type: NodeQueryType.NODE_TYPE,
nodeType: rule.nodeType,
Expand Down
18 changes: 3 additions & 15 deletions src/ui/src/components/visualizer/renderer_wrapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import {
effect,
EventEmitter,
Input,
OnInit,
Output,
ViewChild,
} from '@angular/core';
Expand All @@ -43,7 +42,6 @@ import {AppService} from './app_service';
import {type ModelGraph} from './common/model_graph';
import {
PopupPanelData,
RendererType,
SelectedNodeInfo,
SubgraphBreadcrumbItem,
} from './common/types';
Expand All @@ -54,8 +52,6 @@ import {SubgraphBreadcrumbs} from './subgraph_breadcrumbs';
import {ViewOnNode} from './view_on_node';
import {WebglRenderer} from './webgl_renderer';

const DEFAULT_RENDERER_TYPE: RendererType = RendererType.WEBGL;

/** A wrapper panel around various renderers. */
@Component({
standalone: true,
Expand All @@ -79,7 +75,7 @@ const DEFAULT_RENDERER_TYPE: RendererType = RendererType.WEBGL;
styleUrls: ['./renderer_wrapper.scss'],
changeDetection: ChangeDetectionStrategy.OnPush,
})
export class RendererWrapper implements OnInit {
export class RendererWrapper {
@Input({required: true}) modelGraph!: ModelGraph;
@Input({required: true}) rendererId!: string;
@Input({required: true}) paneId!: string;
Expand All @@ -90,8 +86,6 @@ export class RendererWrapper implements OnInit {

@ViewChild('webglRenderer') webglRenderer?: WebglRenderer;

readonly RendererType = RendererType;

readonly helpPopupSize: OverlaySizeConfig = {
minWidth: 0,
minHeight: 0,
Expand All @@ -101,7 +95,6 @@ export class RendererWrapper implements OnInit {
flattenAllLayers = computed(() =>
this.appService.getFlattenLayers(this.paneId),
);
rendererType = DEFAULT_RENDERER_TYPE;
disableDownloadPngHelpPopup = false;
transparentPngBackground = new FormControl<boolean>(false);

Expand All @@ -118,11 +111,6 @@ export class RendererWrapper implements OnInit {
});
}

ngOnInit() {
this.rendererType =
this.appService.config()?.defaultRenderer ?? DEFAULT_RENDERER_TYPE;
}

handleOpenOnPopupClicked(data: PopupPanelData) {
this.openInPopupClicked.emit(data);
}
Expand Down Expand Up @@ -205,11 +193,11 @@ export class RendererWrapper implements OnInit {
}

get showDownloadPng(): boolean {
return !this.inPopup && this.rendererType === RendererType.WEBGL;
return !this.inPopup;
}

get showSnapshotManager(): boolean {
return !this.inPopup && this.rendererType === RendererType.WEBGL;
return !this.inPopup;
}

get showSubgraphBreadcrumbs(): boolean {
Expand Down
27 changes: 19 additions & 8 deletions src/ui/src/components/visualizer/webgl_renderer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ import {
FontWeight,
NodeDataProviderResultProcessedData,
NodeDataProviderRunData,
NodeStyleId,
NodeStylerRule,
Point,
PopupPanelData,
Expand All @@ -79,6 +80,7 @@ import {
import {
genUid,
getHighQualityPixelRatio,
getNodeStyleValue,
hasNonEmptyQueries,
IS_MAC,
isGroupNode,
Expand All @@ -97,7 +99,7 @@ import {
import {DragArea} from './drag_area';
import {genIoTreeData, IoTree} from './io_tree';
import {NodeDataProviderExtensionService} from './node_data_provider_extension_service';
import {NodeStylerService, StyleId} from './node_styler_service';
import {NodeStylerService} from './node_styler_service';
import {SplitPaneService} from './split_pane_service';
import {SubgraphSelectionService} from './subgraph_selection_service';
import {ThreejsService} from './threejs_service';
Expand Down Expand Up @@ -1767,17 +1769,24 @@ export class WebglRenderer implements OnInit, OnDestroy {
// Node styler.
for (const rule of this.curProcessedNodeStylerRules) {
if (matchNodeForQueries(node, rule.queries, this.curModelGraph)) {
const nodeStylerBgColor =
rule.styles[StyleId.NODE_BG_COLOR]?.value || '';
const nodeStylerBgColor = getNodeStyleValue(
rule,
NodeStyleId.NODE_BG_COLOR,
);
if (nodeStylerBgColor !== '') {
bgColor = new THREE.Color(nodeStylerBgColor);
}
const nodeBorderColor =
rule.styles[StyleId.NODE_BORDER_COLOR]?.value || '';
const nodeBorderColor = getNodeStyleValue(
rule,
NodeStyleId.NODE_BORDER_COLOR,
);
if (nodeBorderColor !== '') {
borderColor = new THREE.Color(nodeBorderColor);
}
const textColor = rule.styles[StyleId.NODE_TEXT_COLOR]?.value || '';
const textColor = getNodeStyleValue(
rule,
NodeStyleId.NODE_TEXT_COLOR,
);
if (textColor !== '') {
groupNodeIconColor = new THREE.Color(textColor);
}
Expand Down Expand Up @@ -2019,8 +2028,10 @@ export class WebglRenderer implements OnInit, OnDestroy {
// Node styler.
for (const rule of this.curProcessedNodeStylerRules) {
if (matchNodeForQueries(node, rule.queries, this.curModelGraph)) {
const nodeStylerTextColor =
rule.styles[StyleId.NODE_TEXT_COLOR]?.value || '';
const nodeStylerTextColor = getNodeStyleValue(
rule,
NodeStyleId.NODE_TEXT_COLOR,
);
if (nodeStylerTextColor !== '') {
color = new THREE.Color(nodeStylerTextColor);
}
Expand Down

0 comments on commit b00be73

Please sign in to comment.