Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to pin a single op node to the top of a layer #116

Merged
merged 1 commit into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/ui/src/components/visualizer/common/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ export const TENSOR_VALUES_KEY = '__value';
/** The key to store the tensor tag in i/o metadata. */
export const TENSOR_TAG_METADATA_KEY = '__tensor_tag';

/** The margin for the left and right side of the layout. */
export const LAYOUT_MARGIN_X = 20;

/** A map from color names to the corresponding hex color. */
export const COLOR_NAME_TO_HEX: Record<string, string> = {
'aliceblue': '#f0f8ff',
Expand Down
4 changes: 4 additions & 0 deletions src/ui/src/components/visualizer/common/input_graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/

import {
GraphNodeConfig,
GraphNodeStyle,
IncomingEdge,
KeyValueList,
Expand Down Expand Up @@ -136,4 +137,7 @@ export declare interface GraphNode {

/** The default style of the node. */
style?: GraphNodeStyle;

/** Custom configs for the node. */
config?: GraphNodeConfig;
}
7 changes: 7 additions & 0 deletions src/ui/src/components/visualizer/common/model_graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/

import {
GraphNodeConfig,
GraphNodeStyle,
IncomingEdge,
KeyValuePairs,
Expand Down Expand Up @@ -206,6 +207,9 @@ export declare interface OpNode extends ModelNodeBase {

/** The style of the node. */
style?: GraphNodeStyle;

/** Custom configs for the node. */
config?: GraphNodeConfig;
}

/**
Expand Down Expand Up @@ -237,6 +241,9 @@ export declare interface GroupNode extends ModelNodeBase {
* nodes to layout.
*/
sectionContainer?: boolean;

/** The op node that should be pinned to the top of the group. */
pinToTopOpNode?: OpNode;
}

/** A node in a model graph. */
Expand Down
6 changes: 6 additions & 0 deletions src/ui/src/components/visualizer/common/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ export declare interface GraphNodeStyle {
hoveredBorderColor?: string;
}

/** Custom configs for a graph node. */
export declare interface GraphNodeConfig {
/** Whether to pin the node to the top of the group it belongs to. */
pinToGroupTop?: boolean;
}

/** Data to pass along when clicking "open in popup" on a group node. */
export interface PopupPanelData {
id: string;
Expand Down
27 changes: 26 additions & 1 deletion src/ui/src/components/visualizer/webgl_renderer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import * as three from 'three';
import {AppService} from './app_service';
import {
GLOBAL_KEY,
LAYOUT_MARGIN_X,
NODE_LABEL_HEIGHT,
WEBGL_ELEMENT_Y_FACTOR,
} from './common/consts';
Expand Down Expand Up @@ -260,6 +261,7 @@ export class WebglRenderer implements OnInit, OnDestroy {
readonly GROUP_NODE_BORDER_COLOR = new THREE.Color('#aaa');
readonly GROUP_NODE_LABEL_SEPARATOR_COLOR = new THREE.Color('#DADCE0');
readonly GROUP_NODE_ICON_COLOR = new THREE.Color('#444746');
readonly GROUP_NODE_PIN_TO_TOP_SEPARATOR_COLOR = new THREE.Color('#bbb');
readonly EDGE_COLOR = new THREE.Color(
this.appService.config()?.edgeColor || '#aaa',
);
Expand Down Expand Up @@ -1776,7 +1778,7 @@ export class WebglRenderer implements OnInit, OnDestroy {
}
nodeBodyRectangles.push({
id: node.id,
index: i,
index: nodeBodyRectangles.length,
bound: {
x: x + width / 2,
y: y + height / 2,
Expand All @@ -1797,6 +1799,29 @@ export class WebglRenderer implements OnInit, OnDestroy {
bgColor.b === 1,
});

// Render separator between the pinned node and the rest of the nodes.
if (isGroupNode(node) && node.expanded && node.pinToTopOpNode) {
nodeBodyRectangles.push({
id: `${node.id}_pin_to_top_separator`,
index: nodeBodyRectangles.length,
bound: {
x: x + width / 2,
y:
(node.pinToTopOpNode.globalY || 0) +
(node.pinToTopOpNode.height || 0) / 2 +
12.5,
width: width - LAYOUT_MARGIN_X * 2,
height: 1,
},
yOffset: WEBGL_ELEMENT_Y_FACTOR * nodeIndex + 0.1,
isRounded: true,
borderColor: this.GROUP_NODE_PIN_TO_TOP_SEPARATOR_COLOR,
bgColor: this.GROUP_NODE_PIN_TO_TOP_SEPARATOR_COLOR,
borderWidth: 1,
opacity: 1,
});
}

// Subgraph indicators.
if (isOpNode(node) && node.subgraphIds) {
const indicatorWidth = SUBGRAPH_INDICATOR_SIZE;
Expand Down
56 changes: 51 additions & 5 deletions src/ui/src/components/visualizer/worker/graph_expander.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
* ==============================================================================
*/

import {GroupNode, ModelGraph} from '../common/model_graph';
import {LAYOUT_MARGIN_X} from '../common/consts';
import {GroupNode, ModelGraph, OpNode} from '../common/model_graph';
import {NodeDataProviderRunData, ShowOnNodeItemData} from '../common/types';
import {getDeepestExpandedGroupNodeIds, isGroupNode} from '../common/utils';

Expand All @@ -25,7 +26,6 @@ import {
GraphLayout,
LAYOUT_MARGIN_BOTTOM,
LAYOUT_MARGIN_TOP,
LAYOUT_MARGIN_X,
getNodeHeight,
getNodeWidth,
} from './graph_layout';
Expand Down Expand Up @@ -84,8 +84,13 @@ export class GraphExpander {

// Grow size.
const curTargetWidth = rect.width + LAYOUT_MARGIN_X * 2;
const curTargetHeight =
let curTargetHeight =
rect.height + LAYOUT_MARGIN_TOP + LAYOUT_MARGIN_BOTTOM;
if (curGroupNode.pinToTopOpNode) {
curTargetHeight += this.getPinToTopNodeVerticalSpace(
curGroupNode.pinToTopOpNode,
);
}
curGroupNode.width = curTargetWidth;
curGroupNode.height = curTargetHeight;

Expand Down Expand Up @@ -158,8 +163,13 @@ export class GraphExpander {

// Grow size.
const curTargetWidth = rect.width + LAYOUT_MARGIN_X * 2;
const curTargetHeight =
let curTargetHeight =
rect.height + LAYOUT_MARGIN_TOP + LAYOUT_MARGIN_BOTTOM;
if (groupNode.pinToTopOpNode) {
curTargetHeight += this.getPinToTopNodeVerticalSpace(
groupNode.pinToTopOpNode,
);
}
groupNode.width = curTargetWidth;
groupNode.height = curTargetHeight;
}
Expand Down Expand Up @@ -263,8 +273,13 @@ export class GraphExpander {

// Shrink size.
const curTargetWidth = rect.width + LAYOUT_MARGIN_X * 2;
const curTargetHeight =
let curTargetHeight =
rect.height + LAYOUT_MARGIN_TOP + LAYOUT_MARGIN_BOTTOM;
if (curGroupNode.pinToTopOpNode) {
curTargetHeight += this.getPinToTopNodeVerticalSpace(
curGroupNode.pinToTopOpNode,
);
}
curGroupNode.width = curTargetWidth;
curGroupNode.height = curTargetHeight;

Expand Down Expand Up @@ -401,6 +416,33 @@ export class GraphExpander {
(groupNode.y || 0) +
(groupNode.globalY || 0) +
(node.localOffsetY || 0);

// Move the node down if the current group node has a node pinned to
// top.
if (
groupNode.pinToTopOpNode &&
node.id !== groupNode.pinToTopOpNode.id
) {
node.globalY += this.getPinToTopNodeVerticalSpace(
groupNode.pinToTopOpNode,
);
}

// For the pinned-to-top node, move it to the top-middle of the group
// node.
if (groupNode.pinToTopOpNode?.id === node.id) {
node.globalX =
(groupNode.x || 0) +
(groupNode.globalX || 0) +
(groupNode.width || 0) / 2;
node.globalY =
(groupNode.y || 0) +
(groupNode.globalY || 0) +
(node.localOffsetY || 0) +
this.getPinToTopNodeVerticalSpace(node as OpNode) -
(node.height || 0) / 2 +
10;
}
}
if (isGroupNode(node)) {
this.updateNodeOffset(node);
Expand Down Expand Up @@ -434,4 +476,8 @@ export class GraphExpander {
}
}
}

private getPinToTopNodeVerticalSpace(node: OpNode): number {
return (node.height || 0) + 20;
}
}
39 changes: 29 additions & 10 deletions src/ui/src/components/visualizer/worker/graph_layout.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/

import {
LAYOUT_MARGIN_X,
MAX_IO_ROWS_IN_ATTRS_TABLE,
NODE_ATTRS_TABLE_FONT_SIZE,
NODE_ATTRS_TABLE_LABEL_VALUE_PADDING,
Expand All @@ -34,6 +35,7 @@ import {
OpNode,
} from '../common/model_graph';
import {
GraphNodeConfig,
KeyValueList,
NodeDataProviderRunData,
Point,
Expand All @@ -57,9 +59,6 @@ import {

import {Dagre, DagreGraphInstance} from './dagre_types';

/** The margin for the left and right side of the layout. */
export const LAYOUT_MARGIN_X = 20;

/** The margin for the top and bottom side of the layout. */
export const LAYOUT_MARGIN_TOP = 36;

Expand All @@ -85,6 +84,7 @@ export declare interface DagreNode {
height: number;
x?: number;
y?: number;
config?: GraphNodeConfig;
}

interface LayoutGraph {
Expand Down Expand Up @@ -143,7 +143,11 @@ export class GraphLayout {

// Set nodes/edges to dagre.
for (const id of Object.keys(layoutGraph.nodes)) {
this.dagreGraph.setNode(id, layoutGraph.nodes[id]);
const dagreNode = layoutGraph.nodes[id];
if (dagreNode.config?.pinToGroupTop) {
continue;
}
this.dagreGraph.setNode(id, dagreNode);
}
for (const fromNodeId of Object.keys(layoutGraph.outgoingEdges)) {
for (const toNodeId of layoutGraph.outgoingEdges[fromNodeId]) {
Expand All @@ -154,7 +158,8 @@ export class GraphLayout {
// Run the layout algorithm.
this.dagre.layout(this.dagreGraph);

// Set the results back to the original model nodes.
// Set the results back to the original model nodes and calculate the bound
// that contains all the nodes.
let minX = Number.MAX_VALUE;
let minY = Number.MAX_VALUE;
let maxX = Number.NEGATIVE_INFINITY;
Expand All @@ -172,13 +177,17 @@ export class GraphLayout {
node.localOffsetX = 0;
node.localOffsetY = 0;

minX = Math.min(minX, node.x);
minY = Math.min(minY, node.y);
maxX = Math.max(maxX, node.x + node.width);
maxY = Math.max(maxY, node.y + node.height);
// Don't consider the bound of the node if it's pinned to the top of the
// group.
if (!dagreNode.config?.pinToGroupTop) {
minX = Math.min(minX, node.x);
minY = Math.min(minY, node.y);
maxX = Math.max(maxX, node.x + node.width);
maxY = Math.max(maxY, node.y + node.height);
}
}

// Edges.
// Expand the bound to include all the edges.
let minEdgeX = Number.MAX_VALUE;
let minEdgeY = Number.MAX_VALUE;
let maxEdgeX = Number.NEGATIVE_INFINITY;
Expand Down Expand Up @@ -511,6 +520,7 @@ export function getLayoutGraph(
nodeDataProviderRuns,
testMode,
),
config: isOpNode(node) ? node.config : undefined,
};
layoutGraph.nodes[node.id] = dagreNode;
}
Expand All @@ -520,6 +530,15 @@ export function getLayoutGraph(
modelGraph.layoutGraphEdges[rootGroupNodeId] || {};
for (const [fromNodeId, toNodeIds] of Object.entries(curLayoutGraphEdges)) {
for (const toNodeId of Object.keys(toNodeIds)) {
// Ignore edges from/to nodes pinned to group top.
const fromNode = modelGraph.nodesById[fromNodeId];
const toNode = modelGraph.nodesById[toNodeId];
if (fromNode && isOpNode(fromNode) && fromNode.config?.pinToGroupTop) {
continue;
}
if (toNode && isOpNode(toNode) && toNode.config?.pinToGroupTop) {
continue;
}
addLayoutGraphEdge(layoutGraph, fromNodeId, toNodeId);
}
}
Expand Down
6 changes: 6 additions & 0 deletions src/ui/src/components/visualizer/worker/graph_processor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ export class GraphProcessor {
if (graphNode.style) {
opNode.style = graphNode.style;
}
if (graphNode.config) {
opNode.config = graphNode.config;
}
modelGraph.nodes.push(opNode);
modelGraph.nodesById[opNode.id] = opNode;

Expand Down Expand Up @@ -290,6 +293,9 @@ export class GraphProcessor {
}
if (!parentGroupNode.nsChildrenIds.includes(node.id)) {
parentGroupNode.nsChildrenIds.push(node.id);
if (isOpNode(node) && node.config?.pinToGroupTop) {
parentGroupNode.pinToTopOpNode = node;
}
}
}
}
Expand Down
Loading