Skip to content

Commit

Permalink
Add support for adding custom attributes to group/layer nodes.
Browse files Browse the repository at this point in the history
- Hook it up with search and view on node.
- Also, minor fix on the expand/collapse icon for info panel sections.
  Now collapsed sections will show ">" icon. and expanded sections will show "V".

PiperOrigin-RevId: 662271571
  • Loading branch information
Google AI Edge authored and copybara-github committed Aug 13, 2024
1 parent b29ce2d commit d5edf3b
Show file tree
Hide file tree
Showing 13 changed files with 220 additions and 17 deletions.
8 changes: 8 additions & 0 deletions src/ui/src/components/visualizer/common/input_graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import {
GraphNodeConfig,
GraphNodeStyle,
GroupNodeAttributes,
IncomingEdge,
KeyValueList,
MetadataItem,
Expand Down Expand Up @@ -68,6 +69,13 @@ export declare interface Graph {
/** A list of nodes in the graph. */
nodes: GraphNode[];

/**
* Attributes for group nodes.
*
* It is displayed in the side panel when the group is selected.
*/
groupNodeAttributes?: GroupNodeAttributes;

//////////////////////////////////////////////////////////////////////////////
// The following fields are set by model explorer. Users don't need to set
// them.
Expand Down
4 changes: 4 additions & 0 deletions src/ui/src/components/visualizer/common/model_graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import {
GraphNodeConfig,
GraphNodeStyle,
GroupNodeAttributes,
IncomingEdge,
KeyValuePairs,
OutgoingEdge,
Expand All @@ -45,6 +46,9 @@ export declare interface ModelGraph {
/** All nodes in the model graph. */
nodes: ModelNode[];

/** Attributes for group nodes. */
groupNodeAttributes?: GroupNodeAttributes;

/** Ids of all group nodes that are artificially created. */
artificialGroupNodeIds?: string[];

Expand Down
15 changes: 15 additions & 0 deletions src/ui/src/components/visualizer/common/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,20 @@ export declare interface Rect {
height: number;
}

/** Attributes for group nodes. */
export declare interface GroupNodeAttributes {
/**
* From group's namespace to its attribuets (key-value pairs).
*
* Use empty group namespace for the model-level attributes (i.e. shown in
* side panel when no node is selected).
*/
[namespaceName: string]: Record<string, GroupNodeAttributeItem>;
}

/** A single attribute item for group node. */
export type GroupNodeAttributeItem = string;

/** The style of the op node. */
export declare interface GraphNodeStyle {
/**
Expand Down Expand Up @@ -475,6 +489,7 @@ export enum ShowOnNodeItemType {
OP_OUTPUTS = 'Op node outputs',
LAYER_NODE_CHILDREN_COUNT = 'Layer node children count',
LAYER_NODE_DESCENDANTS_COUNT = 'Layer node descendants count',
LAYER_NODE_ATTRS = 'Layer node attributes',
}

/** Item types to be shown on edge. */
Expand Down
46 changes: 42 additions & 4 deletions src/ui/src/components/visualizer/common/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ export function getOpNodeAttrsKeyValuePairsForAttrsTable(
processedValue = value.replace(/\s/gm, '');
} else {
// For other attributes, only remove newline chars.
processedValue = value.replace(/(\r\n|\n|\r)/gm, '');
processedValue = value.replace(/(\r\n|\n|\r)/gm, ' ');
}
keyValuePairs.push({
key,
Expand All @@ -409,6 +409,37 @@ export function getOpNodeAttrsKeyValuePairsForAttrsTable(
return keyValuePairs;
}

/**
* Gets the key value pairs for the given group node's attrs for attrs table.
*/
export function getGroupNodeAttrsKeyValuePairsForAttrsTable(
node: GroupNode,
modelGraph: ModelGraph,
filterRegex = '',
) {
const attrs =
modelGraph.groupNodeAttributes?.[node.id.replace('___group___', '')] || {};
const keyValuePairs: KeyValueList = [];
const regex = new RegExp(filterRegex, 'i');
for (const attrId of Object.keys(attrs)) {
const key = attrId;
const value = attrs[attrId];
const matchTargets = [`${key}:${value}`, `${key}=${value}`];
if (
filterRegex.trim() === '' ||
matchTargets.some((matchTarget) => regex.test(matchTarget))
) {
// Remove new line chars and spaces.
const processedValue = value.replace(/(\r\n|\n|\r)/gm, ' ');
keyValuePairs.push({
key,
value: processedValue,
});
}
}
return keyValuePairs;
}

/** Gets the key value pairs for the givn node's input for attrs table. */
export function getOpNodeInputsKeyValuePairsForAttrsTable(
node: OpNode,
Expand Down Expand Up @@ -591,7 +622,7 @@ export function getRegexMatchesForNode(
}
// Attribute.
if (shouldMatchTypes.has(SearchMatchType.ATTRIBUTE)) {
const attrs = getAttributesFromNode(node);
const attrs = getAttributesFromNode(node, modelGraph);
for (const attrId of Object.keys(attrs)) {
const value = attrs[attrId];
const text1 = `${attrId}:${value}`;
Expand Down Expand Up @@ -724,7 +755,10 @@ export function getRegexMatchesForNode(
}

/** Gets the attributes from the given node. */
export function getAttributesFromNode(node: ModelNode): KeyValuePairs {
export function getAttributesFromNode(
node: ModelNode,
modelGraph: ModelGraph,
): KeyValuePairs {
let attrs: KeyValuePairs = {};
if (isOpNode(node)) {
attrs = {...(node.attrs || {})};
Expand All @@ -735,6 +769,10 @@ export function getAttributesFromNode(node: ModelNode): KeyValuePairs {
'#descendants': `${(node.descendantsNodeIds || []).length}`,
'#children': `${(node.nsChildrenIds || []).length}`,
};
const customAttrs =
modelGraph.groupNodeAttributes?.[node.id.replace('___group___', '')] ||
{};
attrs = {...attrs, ...customAttrs};
}
return attrs;
}
Expand All @@ -749,7 +787,7 @@ export function getAttrValueRangeMatchesForNode(
): SearchMatch[] {
const matches: SearchMatch[] = [];

const attrs = getAttributesFromNode(node);
const attrs = getAttributesFromNode(node, modelGraph);
const value = attrs[attrName];
if (value != null) {
const numValue = Number(value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
[style.color]="textColor"
[class.has-bg-color]="hasBgColor"
[class.has-overflow]="hasOverflow"
[class.has-multiple-lines]="hasMultipleLines"
[class.expanded]="expanded"
(click)="handleToggleExpand($event, true)">
<div class="expanded-text">
Expand Down Expand Up @@ -62,7 +63,7 @@
} @else if (type === 'quantization') {
<div class="monospace-content">{{formatQuantization}}</div>
} @else {
{{text}}
<div class="text-content">{{text}}</div>
}
</div>
<div class="one-line-text" #oneLineText>
Expand Down
9 changes: 8 additions & 1 deletion src/ui/src/components/visualizer/expandable_info_text.scss
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
align-items: flex-start;
position: relative;

&.has-overflow {
&.has-overflow,
&.has-multiple-lines {
cursor: pointer;

.icon-container {
Expand Down Expand Up @@ -73,6 +74,12 @@
overflow: auto;
}

.text-content {
white-space: pre-wrap;
max-height: 500px;
overflow: auto;
}

.namespace-content {
display: flex;
flex-direction: column;
Expand Down
6 changes: 5 additions & 1 deletion src/ui/src/components/visualizer/expandable_info_text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ export class ExpandableInfoText implements AfterViewInit, OnDestroy, OnChanges {
}

handleToggleExpand(event: MouseEvent, fromExpandedText = false) {
if (!this.hasOverflow) {
if (!this.hasOverflow && !this.hasMultipleLines) {
return;
}

Expand All @@ -116,6 +116,10 @@ export class ExpandableInfoText implements AfterViewInit, OnDestroy, OnChanges {
return this.hasOverflowInternal;
}

get hasMultipleLines(): boolean {
return this.type !== 'namespace' && this.text.includes('\n');
}

get iconName(): string {
return this.expanded ? 'unfold_less' : 'unfold_more';
}
Expand Down
38 changes: 37 additions & 1 deletion src/ui/src/components/visualizer/info_panel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ enum SectionLabel {
GRAPH_INFO = 'Graph info',
NODE_INFO = 'Node info',
LAYER_INFO = 'Layer info',
LAYER_ATTRS = 'Layer attributes',
ATTRIBUTES = 'Attributes',
NODE_DATA_PROVIDERS = 'Node data providers',
IDENTICAL_GROUPS = 'Identical groups',
Expand Down Expand Up @@ -410,7 +411,9 @@ export class InfoPanel {
}

getSectionToggleIcon(sectionName: string): string {
return this.isSectionCollapsed(sectionName) ? 'expand_more' : 'expand_less';
return this.isSectionCollapsed(sectionName)
? 'chevron_right'
: 'expand_more';
}

handleLocateNode(nodeId: string, event: MouseEvent) {
Expand Down Expand Up @@ -647,6 +650,18 @@ export class InfoPanel {
value: String(layerCount),
},
);

// Custom attributes.
const graphAttributes = this.curModelGraph.groupNodeAttributes?.[''];
if (graphAttributes) {
for (const key of Object.keys(graphAttributes)) {
graphSection.items.push({
section: graphSection,
label: key,
value: graphAttributes[key],
});
}
}
}

private genInfoDataForSelectedOpNode() {
Expand Down Expand Up @@ -892,6 +907,27 @@ export class InfoPanel {
showOnNode: this.curShowOnGroupNodeInfoIds.has(label),
});

// Section for custom attributes.
const groupAttributes =
this.curModelGraph.groupNodeAttributes?.[
groupNode.id.replace('___group___', '')
];
if (groupAttributes) {
const attrsSection: InfoSection = {
label: SectionLabel.LAYER_ATTRS,
sectionType: 'group',
items: [],
};
this.sections.push(attrsSection);
for (const key of Object.keys(groupAttributes)) {
attrsSection.items.push({
section: nodeSection,
label: key,
value: groupAttributes[key],
});
}
}

// Section for identical groups.
if (groupNode.identicalGroupIndex != null) {
this.identicalGroupNodes = this.curModelGraph.nodes.filter(
Expand Down
4 changes: 2 additions & 2 deletions src/ui/src/components/visualizer/view_on_node.ng.html
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@
<div class="show-on-node-filter">
<input class="input-attrs-filter"
placeholder="Filter by regex" [disabled]="!item.selected"
[value]="curAttrsFilterText"
[value]="getAttrsFilterText(item)"
(keydown.enter)="input.blur()"
(input)="curAttrsFilterText = input.value"
(input)="setAttrsFilterText(item, input.value)"
(change)="handleAttrsFilterChanged(item)"
#input>
<div class="icon-container"
Expand Down
42 changes: 38 additions & 4 deletions src/ui/src/components/visualizer/view_on_node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ const ALL_SHOW_ON_NODE_ITEM_TYPES: ShowOnNodeItemType[] = [
ShowOnNodeItemType.OP_OUTPUTS,
ShowOnNodeItemType.LAYER_NODE_CHILDREN_COUNT,
ShowOnNodeItemType.LAYER_NODE_DESCENDANTS_COUNT,
ShowOnNodeItemType.LAYER_NODE_ATTRS,
];

const ALL_SHOW_ON_EDGE_ITEM_TYPES: ShowOnEdgeItemType[] = [
Expand Down Expand Up @@ -122,7 +123,8 @@ export class ViewOnNode {

showOnNodeItems: ShowOnNodeItem[] = [];
showOnEdgeItems: ShowOnEdgeItem[] = [];
curAttrsFilterText = '';
curOpAttrsFilterText = '';
curGroupAttrsFilterText = '';
opened = false;

constructor(
Expand Down Expand Up @@ -160,7 +162,12 @@ export class ViewOnNode {
item.filterRegex =
(curShowOnNodeItemTypes[this.rendererId] || {})[type]
?.filterRegex || '';
this.curAttrsFilterText = item.filterRegex;
this.curOpAttrsFilterText = item.filterRegex;
} else if (type === ShowOnNodeItemType.LAYER_NODE_ATTRS) {
item.filterRegex =
(curShowOnNodeItemTypes[this.rendererId] || {})[type]
?.filterRegex || '';
this.curGroupAttrsFilterText = item.filterRegex;
}
}

Expand Down Expand Up @@ -227,15 +234,42 @@ export class ViewOnNode {
this.paneId,
this.rendererId,
item.type,
this.curAttrsFilterText,
this.getAttrsFilterText(item),
);

// Save to local storage.
this.saveShowOnNodeItemsToLocalStorage();
}

getAttrsFilterText(item: ShowOnNodeItem): string {
switch (item.type) {
case ShowOnNodeItemType.OP_ATTRS:
return this.curOpAttrsFilterText;
case ShowOnNodeItemType.LAYER_NODE_ATTRS:
return this.curGroupAttrsFilterText;
default:
return '';
}
}

setAttrsFilterText(item: ShowOnNodeItem, text: string) {
switch (item.type) {
case ShowOnNodeItemType.OP_ATTRS:
this.curOpAttrsFilterText = text;
break;
case ShowOnNodeItemType.LAYER_NODE_ATTRS:
this.curGroupAttrsFilterText = text;
break;
default:
break;
}
}

getIsAttrs(item: ShowOnNodeItem): boolean {
return item.type === ShowOnNodeItemType.OP_ATTRS;
return (
item.type === ShowOnNodeItemType.OP_ATTRS ||
item.type === ShowOnNodeItemType.LAYER_NODE_ATTRS
);
}

private saveShowOnNodeItemsToLocalStorage() {
Expand Down
Loading

0 comments on commit d5edf3b

Please sign in to comment.