LeetCode 1339. Maximum Product of Splitted Binary Tree
Problem Description
Given the root of a binary tree, split the binary tree into two subtrees by removing one edge such that the product of the sums of the subtrees is maximized.
Return the maximum product of the sums of the two subtrees. Since the answer may be too large, return it modulo 10^9 + 7.
Note that you need to maximize the answer before taking the mod and not after taking it.
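A small Java sketch shows why the order matters. It uses two made-up candidate products, not values from a real tree, and the class and method names are hypothetical:

```java
// Compares "maximize, then reduce" with "reduce, then maximize".
// The candidate products are illustrative values only.
class ModOrderDemo {
    static final long MOD = 1_000_000_007L;

    // Correct order: compare the raw products, reduce only the winner.
    static long maxThenMod(long[] products) {
        long best = 0;
        for (long p : products) {
            best = Math.max(best, p);
        }
        return best % MOD;
    }

    // Wrong order: reducing first changes which product looks largest.
    static long modThenMax(long[] products) {
        long best = 0;
        for (long p : products) {
            best = Math.max(best, p % MOD);
        }
        return best;
    }

    public static void main(String[] args) {
        long[] products = {1_000_000_008L, 999_999_999L};
        System.out.println(maxThenMod(products)); // 1
        System.out.println(modThenMax(products)); // 999999999
    }
}
```

The true maximum is 1,000,000,008, which reduces to 1; reducing first would wrongly pick 999,999,999.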
Solution Approach
The key insight is to:
1. First calculate the total sum of all nodes in the tree
2. For each possible edge removal (which splits the tree into two subtrees), calculate the sum of one subtree
3. The sum of the other subtree is simply total - subtreeSum
4. Track the maximum product: subtreeSum * (total - subtreeSum)
Design Intuition & Why This Approach
When solving this problem, there are several approaches we could consider:
Alternative Approach 1: Precalculate Subtree Sums
Idea: Compute the sum of every subtree, store the results, then scan them for the best split.
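For comparison, here is a minimal Java sketch of this alternative. It assumes a stand-alone TreeNode that mirrors LeetCode's definition, and the class name PrecomputedSums is hypothetical. One post-order traversal appends every subtree sum to a list (n entries, so O(n) extra space), and a second loop scans the list:

```java
import java.util.ArrayList;
import java.util.List;

// Stand-in for LeetCode's TreeNode (assumed to have the same fields).
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

// Alternative 1 sketch: store every subtree sum, then pick the best split.
class PrecomputedSums {
    static final int MOD = 1_000_000_007;

    // Post-order: returns the subtree sum and records it in `sums`.
    static long collect(TreeNode node, List<Long> sums) {
        if (node == null) return 0;
        long s = node.val + collect(node.left, sums) + collect(node.right, sums);
        sums.add(s);
        return s;
    }

    static int maxProduct(TreeNode root) {
        List<Long> sums = new ArrayList<>();    // O(n) extra space
        long total = collect(root, sums);       // root's sum is recorded last
        long best = 0;
        for (long s : sums) {
            // The root's entry gives total * 0 = 0, so it never wins.
            best = Math.max(best, s * (total - s));
        }
        return (int) (best % MOD);
    }
}
```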
Why we don't do this:
1. Cannot modify TreeNode structure: We cannot store long values in the tree nodes themselves (the TreeNode class is predefined and only has int val).
2. Extra space overhead: Creating a separate data structure (like a HashMap) to store subtree sums would require O(n) extra space.
Alternative Approach 2: Create a New Tree with Subtree Sums
Idea: Build a new tree where each node contains the sum of its subtree.
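A rough Java sketch of this idea follows, for illustration only. TreeNode is an assumed stand-in for LeetCode's class, and SumNode and SumTreeAlternative are hypothetical names. The mirror tree allocates one SumNode per original node:

```java
// Stand-in for LeetCode's TreeNode (assumed to have the same fields).
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

// Hypothetical node type for the second tree: stores a subtree sum.
class SumNode {
    final long sum;
    final SumNode left, right;
    SumNode(long sum, SumNode left, SumNode right) {
        this.sum = sum;
        this.left = left;
        this.right = right;
    }
}

class SumTreeAlternative {
    static final long MOD = 1_000_000_007L;

    // Pass 1: build a mirror tree whose nodes hold subtree sums (n allocations).
    static SumNode build(TreeNode node) {
        if (node == null) return null;
        SumNode l = build(node.left);
        SumNode r = build(node.right);
        long s = node.val + (l == null ? 0 : l.sum) + (r == null ? 0 : r.sum);
        return new SumNode(s, l, r);
    }

    // Pass 2: best product over all nodes; the root yields total * 0 = 0.
    static long best(SumNode node, long total) {
        if (node == null) return 0;
        long here = node.sum * (total - node.sum);
        return Math.max(here, Math.max(best(node.left, total), best(node.right, total)));
    }

    // Assumes a non-empty tree, as the problem guarantees.
    static int maxProduct(TreeNode root) {
        SumNode sumRoot = build(root);
        return (int) (best(sumRoot, sumRoot.sum) % MOD);
    }
}
```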
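Why we don't do this:
1. Extra space overhead: The new tree has one node per original node, so it costs O(n) extra space, the same drawback as Alternative 1.
2. Extra work: Building the tree is an additional full traversal with n allocations, only to read each stored sum once.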
Our Optimal Approach: Calculate During Traversal
Why this works best:
1. No separate precomputation pass: We already need to traverse the tree to check each edge. Computing each subtree sum during that same traversal avoids a dedicated pass for subtree sums; the only other pass is the one that computes the total.
2. No auxiliary storage: Subtree sums are never stored. Each one is computed on the fly, used immediately to compute a product, and returned to the parent.
3. Optimal time complexity: The time complexity remains O(n) because:
- First pass: Calculate total sum (O(n))
- Second pass: Calculate subtree sums and check products (O(n))
- Total: O(2n) = O(n), asymptotically the same as a single pass
4. Space efficient: Only O(h) space for recursion stack, where h is the height of the tree.
The key insight: Since we're already iterating through the tree to check each possible edge removal, we can calculate the subtree sum at each node during this iteration. This eliminates the need for precomputation or extra data structures, making our solution both time and space efficient.
Code Solution
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    private static final int MOD = 1000000007;
    // Best product seen so far; kept as long and reduced modulo MOD only at the end.
    private long bestProduct = 0;

    public int maxProduct(TreeNode root) {
        bestProduct = 0;                // reset in case the instance is reused
        long total = getSum(root);      // pass 1: sum of all nodes
        getMaxProduct(root, total);     // pass 2: try every split
        return (int) (bestProduct % MOD);
    }

    // Post-order DFS: returns the sum of the subtree rooted at `root`
    // and updates bestProduct for the split that cuts this subtree off.
    long getMaxProduct(TreeNode root, long total) {
        if (root == null) {
            return 0;
        }
        long sumTreeSum = root.val + getMaxProduct(root.left, total) + getMaxProduct(root.right, total);
        if (sumTreeSum != total) {      // skip the root: it has no parent edge
            long product = sumTreeSum * (total - sumTreeSum);
            bestProduct = Math.max(bestProduct, product);
        }
        return sumTreeSum;
    }

    // Sum of all node values in the subtree rooted at `root`.
    long getSum(TreeNode root) {
        if (root == null) {
            return 0;
        }
        return root.val + getSum(root.left) + getSum(root.right);
    }
}
Detailed Explanation
Step 1: Calculate Total Sum
The getSum() method performs a post-order traversal to calculate the sum of all nodes in the binary tree. This gives us the total sum that we'll use to calculate the other subtree's sum when we split the tree.
long getSum(TreeNode root) {
    if (root == null) return 0;
    return root.val + getSum(root.left) + getSum(root.right);
}
Time Complexity: O(n) where n is the number of nodes
Space Complexity: O(h) where h is the height of the tree (recursion stack)
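Because a sum does not depend on the order in which nodes are visited, Step 1 could also be written iteratively. The following is a minimal sketch, not the article's method; TreeNode is an assumed stand-in and IterativeTotal is a hypothetical name. The explicit stack lives on the heap, which may help if the default thread stack is too small for a very deep, skewed tree:

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Stand-in for LeetCode's TreeNode (assumed to have the same fields).
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

class IterativeTotal {
    // Adds up every node value with an explicit stack instead of recursion.
    static long getSum(TreeNode root) {
        long total = 0;
        Deque<TreeNode> stack = new ArrayDeque<>();
        if (root != null) {
            stack.push(root);
        }
        while (!stack.isEmpty()) {
            TreeNode node = stack.pop();
            total += node.val;
            // ArrayDeque rejects null elements, so push only existing children.
            if (node.left != null) stack.push(node.left);
            if (node.right != null) stack.push(node.right);
        }
        return total;
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(1,
                new TreeNode(2, new TreeNode(4), new TreeNode(5)),
                new TreeNode(3));
        System.out.println(getSum(root)); // 15
    }
}
```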
Step 2: Find Maximum Product
The getMaxProduct() method is the core of the solution. It uses a post-order DFS traversal to efficiently calculate subtree sums and check products in a single pass:
1. Calculate subtree sum on-the-fly: For each node, it calculates the sum of the subtree rooted at that node by recursively getting sums from left and right children, then adding the current node's value. This is done during the traversal - no precomputation needed.
2. Check if it's a valid split: The condition if(sumTreeSum != total) skips the root, whose subtree is the whole tree (no edge has been removed). LeetCode's constraints make every node value positive, so only the root's subtree sum equals total. The check is also harmless in general: any subtree whose sum equals total would give a product of total * 0 = 0, which can never exceed the maximum.
3. Calculate and track maximum product immediately: For each valid split point, we calculate:
- subtreeSum: The sum of one subtree (the one we're currently at)
- total - subtreeSum: The sum of the other subtree (the rest of the tree)
- product = subtreeSum * (total - subtreeSum): The product we want to maximize
We calculate this product right when we have the subtree sum, eliminating the need to store it for later.
4. Return subtree sum for parent calculation: The method returns the sum of the current subtree, which is used by parent nodes to calculate their own subtree sums. This creates a bottom-up calculation flow where each node's subtree sum is computed from its children's sums.
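To make the four steps concrete, here is an instrumented Java sketch of the same logic that records every candidate split. TreeNode is an assumed stand-in, and SplitTrace with its fields are hypothetical names used only for illustration:

```java
import java.util.ArrayList;
import java.util.List;

// Stand-in for LeetCode's TreeNode (assumed to have the same fields).
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

// Copy of the getMaxProduct idea that logs each split it evaluates.
class SplitTrace {
    final List<String> log = new ArrayList<>();
    long best = 0;

    static long sum(TreeNode node) {
        return node == null ? 0 : node.val + sum(node.left) + sum(node.right);
    }

    long visit(TreeNode node, long total) {
        if (node == null) return 0;
        // Step 1: children first, then the current node (post-order).
        long sub = node.val + visit(node.left, total) + visit(node.right, total);
        if (sub != total) {                      // Step 2: skip the root
            long product = sub * (total - sub);  // Step 3: this subtree vs the rest
            best = Math.max(best, product);
            log.add("node " + node.val + ": sum=" + sub + ", product=" + product);
        }
        return sub;                              // Step 4: hand the sum to the parent
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(1,
                new TreeNode(2, new TreeNode(4), new TreeNode(5)),
                new TreeNode(3));
        SplitTrace trace = new SplitTrace();
        trace.visit(root, sum(root));
        trace.log.forEach(System.out::println);
        System.out.println("best=" + trace.best);
    }
}
```

On the tree 1 -> (2 -> (4, 5), 3), the log lists nodes 4, 5, 2, 3 in that order: a node's sum is available only after both of its children have returned.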
Why this is efficient: By combining the subtree sum calculation with the product checking in a single traversal, we avoid a separate pass to precompute subtree sums and any extra data structure to store them.
Why This Works
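When we traverse the tree using post-order DFS, each non-root node defines one split:
- Subtree 1: The subtree rooted at the current node (sum = sumTreeSum)
- Subtree 2: The rest of the tree (sum = total - sumTreeSum)
Every edge joins exactly one node to its parent, so non-root nodes and edges are in one-to-one correspondence. Visiting every node therefore examines every possible split exactly once, in a single traversal.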
Key Observations
1. Edge removal = subtree consideration: Each node (except root) represents a potential edge removal point. When we calculate a subtree sum, we're effectively considering what would happen if we removed the edge connecting this subtree to its parent.
2. Root exclusion: The check sumTreeSum != total prevents considering the root as a split point, since removing an edge above the root doesn't make sense (the root has no parent edge).
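3. Overflow and modulo handling: Sums and products are kept in long. Under LeetCode's constraints (at most 5 × 10^4 nodes, each value at most 10^4), the total is at most 5 × 10^8, so a product s * (total - s) is at most (2.5 × 10^8)^2 = 6.25 × 10^16, well within the range of long. The modulo is applied only once, to the final maximum, which satisfies the requirement to maximize before taking the mod.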
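The following Java sketch shows what long protects against. The operand 250,000,000 is roughly half of the largest possible total, assuming LeetCode's constraints, and OverflowDemo is a hypothetical name:

```java
// Shows why products of subtree sums are computed in long rather than int.
class OverflowDemo {
    static int intProduct(int a, int b) {
        return a * b;           // 32-bit multiply: silently wraps around
    }

    static long longProduct(int a, int b) {
        return (long) a * b;    // widen first, then multiply in 64 bits
    }

    public static void main(String[] args) {
        int half = 250_000_000;
        System.out.println(intProduct(half, half));   // wrapped, meaningless value
        System.out.println(longProduct(half, half));  // 62500000000000000
    }
}
```

In the solution itself, sumTreeSum and total are already long, so the multiplication happens in 64 bits without an explicit cast.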
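Time & Space Complexity
Time Complexity: O(n), since getSum() and getMaxProduct() each visit every node once (two passes, O(2n) = O(n)).
Space Complexity: O(h) for the recursion stack, where h is the height of the tree: O(log n) for a balanced tree and O(n) for a completely skewed one.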
Example Walkthrough
Consider a tree:
    1
   / \
  2   3
 / \
4   5
1. Total sum = 1 + 2 + 3 + 4 + 5 = 15
2. At node 2 (sum = 11): product = 11 * (15 - 11) = 11 * 4 = 44
3. At node 3 (sum = 3): product = 3 * (15 - 3) = 3 * 12 = 36
4. At node 4 (sum = 4): product = 4 * (15 - 4) = 4 * 11 = 44
5. At node 5 (sum = 5): product = 5 * (15 - 5) = 5 * 10 = 50
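6. At the root (sum = 15): skipped, since sumTreeSum == total
Maximum product = 50, obtained by removing the edge between nodes 2 and 5 (part sums 5 and 10)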
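To check the walkthrough mechanically, here is a small self-contained Java harness. The TreeNode stand-in and the WalkthroughCheck class are assumptions, not part of the original solution. It builds the example tree and runs the same two-pass algorithm:

```java
// Stand-in for LeetCode's TreeNode (assumed to have the same fields).
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

// Same two-pass algorithm as the article's Solution, repeated so this file compiles alone.
class Solution {
    private static final int MOD = 1_000_000_007;
    private long bestProduct = 0;

    public int maxProduct(TreeNode root) {
        bestProduct = 0;
        long total = getSum(root);
        getMaxProduct(root, total);
        return (int) (bestProduct % MOD);
    }

    long getMaxProduct(TreeNode root, long total) {
        if (root == null) return 0;
        long sub = root.val + getMaxProduct(root.left, total) + getMaxProduct(root.right, total);
        if (sub != total) {
            bestProduct = Math.max(bestProduct, sub * (total - sub));
        }
        return sub;
    }

    long getSum(TreeNode root) {
        if (root == null) return 0;
        return root.val + getSum(root.left) + getSum(root.right);
    }
}

class WalkthroughCheck {
    public static void main(String[] args) {
        TreeNode root = new TreeNode(1,
                new TreeNode(2, new TreeNode(4), new TreeNode(5)),
                new TreeNode(3));
        System.out.println(new Solution().maxProduct(root)); // 50
    }
}
```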