LeetCode 1339. Maximum Product of Splitted Binary Tree

7 January 2026

Problem Description

Given the root of a binary tree, split the binary tree into two subtrees by removing one edge such that the product of the sums of the subtrees is maximized.

Return the maximum product of the sums of the two subtrees. Since the answer may be too large, return it modulo 10^9 + 7.

Note that you need to maximize the answer before taking the mod and not after taking it.

Solution Approach

The key insight is to:

1. First calculate the total sum of all nodes in the tree

2. For each possible edge removal (which splits the tree into two subtrees), calculate the sum of one subtree

3. The sum of the other subtree is simply total - subtreeSum

4. Track the maximum product: subtreeSum * (total - subtreeSum)

Design Intuition & Why This Approach

When solving this problem, there are several approaches we could consider:

Alternative Approach 1: Precalculate Subtree Sums

Idea: Calculate the sum of each subtree and store it somewhere for later use.

Why we don't do this:

1. Cannot modify TreeNode structure: We cannot store long values in the tree nodes themselves (the TreeNode class is predefined and only has int val).

2. Extra space overhead: Creating a separate data structure (like a HashMap) to store subtree sums would require O(n) extra space.
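For comparison, here is a minimal sketch of what this rejected alternative might look like. The class name PrecomputedSums and the stand-in TreeNode are assumptions made so the example is self-contained; on LeetCode the TreeNode class is supplied for you. It gives the same answer with one traversal, but it keeps every subtree sum in a list, which is the O(n) extra space mentioned above.

```java
import java.util.ArrayList;
import java.util.List;

// Stand-in for LeetCode's predefined TreeNode (assumption, for a self-contained example).
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

// Hypothetical name: the "store every subtree sum first" alternative.
class PrecomputedSums {

    // Post-order traversal that appends the sum of every subtree to sums
    // and returns the sum of the subtree rooted at node.
    static long collect(TreeNode node, List<Long> sums) {
        if (node == null) {
            return 0;
        }
        long sum = node.val + collect(node.left, sums) + collect(node.right, sums);
        sums.add(sum);
        return sum;
    }

    static int maxProduct(TreeNode root) {
        List<Long> sums = new ArrayList<>();   // O(n) extra space
        long total = collect(root, sums);
        long best = 0;
        for (long sum : sums) {
            // The root's own sum equals total, so its product is 0 and never wins.
            best = Math.max(best, sum * (total - sum));
        }
        return (int) (best % 1_000_000_007L);
    }
}
```

The trade-off is one traversal plus an O(n) list, versus two traversals and no list in the approach this post settles on.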

Alternative Approach 2: Create a New Tree with Subtree Sums

Idea: Build a new tree where each node contains the sum of its subtree.

Why we don't do this:

  • Extra space complexity: Creating a new tree would require O(n) additional space, which is unnecessary.

Our Optimal Approach: Calculate During Traversal

Why this works best:

1. No separate pass for subtree sums: We already need to traverse the tree to check each edge. By calculating each subtree sum during that same traversal, the only other pass we need is the one that computes the total, so the time complexity stays O(n).

2. No auxiliary storage: We don't need to store subtree sums anywhere. We calculate them on the fly and use them immediately to compute the product.

3. Optimal time complexity: The time complexity remains O(n) because:

- First pass: Calculate total sum (O(n))

- Second pass: Calculate subtree sums and check products (O(n))

- Total: O(n) + O(n) = O(n)

4. Space efficient: Only O(h) space for the recursion stack, where h is the height of the tree. For a completely skewed tree, h can be as large as n.

The key insight: Since we're already iterating through the tree to check each possible edge removal, we can calculate the subtree sum at each node during this iteration. This eliminates the need for precomputed sums or extra data structures, making our solution both time and space efficient.

Code Solution

java
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {

    private static final int MOD = 1000000007;
    // Largest product seen so far, kept unreduced so comparisons stay correct.
    private long maxProduct = 0;

    public int maxProduct(TreeNode root) {
        // First pass: sum of every node in the tree.
        long total = getSum(root);
        // Second pass: compute each subtree sum and update maxProduct.
        getMaxProduct(root, total);
        // Apply the modulo once, after the maximum is known.
        return (int) (maxProduct % MOD);
    }

    // Returns the sum of the subtree rooted at root and updates maxProduct on the way up.
    long getMaxProduct(TreeNode root, long total) {
        if (root == null) {
            return 0;
        }
        long sumTreeSum = root.val + getMaxProduct(root.left, total) + getMaxProduct(root.right, total);
        // Skip the root: its subtree is the whole tree, so no edge is removed.
        if (sumTreeSum != total) {
            long product = sumTreeSum * (total - sumTreeSum);
            maxProduct = Math.max(maxProduct, product);
        }
        return sumTreeSum;
    }

    // Returns the sum of all node values in the tree rooted at root.
    long getSum(TreeNode root) {
        if (root == null) {
            return 0;
        }
        return root.val + getSum(root.left) + getSum(root.right);
    }
}

Detailed Explanation

Step 1: Calculate Total Sum

The getSum() method performs a post-order traversal to calculate the sum of all nodes in the binary tree. This gives us the total sum that we'll use to calculate the other subtree's sum when we split the tree.

java
long getSum(TreeNode root) {
    if (root == null) return 0;
    return root.val + getSum(root.left) + getSum(root.right);
}

Time Complexity: O(n) where n is the number of nodes

Space Complexity: O(h) where h is the height of the tree (recursion stack)
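The total does not depend on visiting order, so this step could also be written without recursion. The sketch below is one possible variant, not part of the solution above: it uses an explicit stack, and the class name IterativeSum and the stand-in TreeNode are assumptions made to keep the example self-contained.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Stand-in for LeetCode's predefined TreeNode (assumption, for a self-contained example).
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

// Hypothetical name: an iterative version of the total-sum step.
class IterativeSum {

    static long getSum(TreeNode root) {
        long total = 0;
        Deque<TreeNode> stack = new ArrayDeque<>();
        // ArrayDeque rejects null elements, so only non-null nodes are pushed.
        if (root != null) {
            stack.push(root);
        }
        while (!stack.isEmpty()) {
            TreeNode node = stack.pop();
            total += node.val;
            if (node.left != null) {
                stack.push(node.left);
            }
            if (node.right != null) {
                stack.push(node.right);
            }
        }
        return total;
    }
}
```

This keeps the O(n) time but moves the pending nodes from the call stack onto the heap, which may matter for very deep, skewed trees.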

Step 2: Find Maximum Product

The getMaxProduct() method is the core of the solution. It uses a post-order DFS traversal to calculate subtree sums and check products in the same pass:

1. Calculate subtree sum on the fly: For each node, it calculates the sum of the subtree rooted at that node by recursively getting sums from the left and right children, then adding the current node's value. This is done during the traversal, with no stored sums needed.

2. Check if it's a valid split: The condition if(sumTreeSum != total) skips the root, whose subtree is the whole tree, so no edge has actually been removed. Every other node corresponds to removing the edge to its parent. Because LeetCode's constraints guarantee positive node values, only the root's subtree can sum to total.

3. Calculate and track maximum product immediately: For each valid split point, we calculate:

- sumTreeSum: The sum of one subtree (the one we're currently at)

- total - sumTreeSum: The sum of the other subtree (the rest of the tree)

- product = sumTreeSum * (total - sumTreeSum): The product we want to maximize

We calculate this product right when we have the subtree sum, eliminating the need to store it for later.

4. Return subtree sum for parent calculation: The method returns the sum of the current subtree, which is used by parent nodes to calculate their own subtree sums. This creates a bottom-up calculation flow where each node's subtree sum is computed from its children's sums.

Why this is efficient: By combining the subtree sum calculation with the product checking in a single traversal, we avoid:

  • Storing subtree sums in a separate data structure (saves O(n) space)
  • A further traversal just to evaluate the products once the sums are known
  • Modifying the tree structure (the input tree is left unchanged)

Why This Works

When we traverse the tree using post-order DFS:

  • We visit each node after visiting its children
  • At each node, we know the sum of its left and right subtrees
  • Removing the edge above the current node would split the tree into:
      - Subtree 1: The subtree rooted at the current node (sum = sumTreeSum)
      - Subtree 2: The rest of the tree (sum = total - sumTreeSum)

The algorithm explores all possible split points in a single traversal.

Key Observations

1. Edge removal = subtree consideration: Each node (except the root) represents a potential edge removal point. When we calculate a subtree sum, we're effectively considering what would happen if we removed the edge connecting this subtree to its parent.

2. Root exclusion: The check sumTreeSum != total prevents considering the root as a split point, since the root has no parent edge to remove.

3. Modulo handling: We use long to avoid integer overflow during calculations, and only apply the modulo at the end when returning the result. Reducing each product modulo 10^9 + 7 before comparing could make a larger product look smaller and select the wrong maximum.
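To make the modulo observation concrete, here is a small sketch. The class name ModOrderDemo and the numbers are hypothetical, chosen only to show the effect; they do not come from a particular tree. It compares "take the maximum, then the modulo" with "take the modulo, then the maximum", and shows why the multiplication needs to be widened to long.

```java
// Hypothetical demo class; the input values are made up for illustration.
class ModOrderDemo {

    static final long MOD = 1_000_000_007L;

    // What the problem asks for: maximize first, reduce once at the end.
    static long maxThenMod(long[] products) {
        long best = 0;
        for (long p : products) {
            best = Math.max(best, p);
        }
        return best % MOD;
    }

    // The mistake: reducing each product before comparing.
    static long modThenMax(long[] products) {
        long best = 0;
        for (long p : products) {
            best = Math.max(best, p % MOD);
        }
        return best;
    }

    // Casting one operand to long makes the multiplication happen in 64 bits.
    static long widenedProduct(int a, int b) {
        return (long) a * b;
    }

    public static void main(String[] args) {
        long[] products = {1_000_000_008L, 500L};
        System.out.println(maxThenMod(products)); // prints 1
        System.out.println(modThenMax(products)); // prints 500
        System.out.println(widenedProduct(250_000_000, 250_000_000)); // prints 62500000000000000
    }
}
```

The larger product, 1,000,000,008, reduces to 1, so comparing reduced values would wrongly prefer 500. The last line shows two subtree sums of 250,000,000 each, whose product fits comfortably in a long but not in an int.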

Time & Space Complexity

  • Time Complexity: O(n) - getSum() and getMaxProduct() each visit every node exactly once
  • Space Complexity: O(h) - Where h is the height of the tree, due to the recursion stack

Example Walkthrough

Consider a tree:

        1
       / \
      2   3
     / \
    4   5

1. Total sum = 1 + 2 + 3 + 4 + 5 = 15

2. At node 2 (sum = 11): product = 11 * (15 - 11) = 11 * 4 = 44

3. At node 3 (sum = 3): product = 3 * (15 - 3) = 3 * 12 = 36

4. At node 4 (sum = 4): product = 4 * (15 - 4) = 4 * 11 = 44

5. At node 5 (sum = 5): product = 5 * (15 - 5) = 5 * 10 = 50

6. At node 1 (the root, sum = 15): skipped, because the sum equals the total and no edge is removed

Maximum product = 50
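
The walkthrough can be reproduced with a short program. This is a sketch under a few assumptions: the Walkthrough class, its table() helper, and the stand-in TreeNode are names introduced here for illustration. It uses the same sumTreeSum != total check as the solution and records one row per candidate split.

```java
import java.util.ArrayList;
import java.util.List;

// Stand-in for LeetCode's predefined TreeNode (assumption, for a self-contained example).
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

// Hypothetical helper that lists every candidate split of a tree.
class Walkthrough {

    static long getSum(TreeNode node) {
        if (node == null) {
            return 0;
        }
        return node.val + getSum(node.left) + getSum(node.right);
    }

    // Post-order traversal; adds {node value, subtree sum, product} for every non-root node.
    static long visit(TreeNode node, long total, List<long[]> rows) {
        if (node == null) {
            return 0;
        }
        long sum = node.val + visit(node.left, total, rows) + visit(node.right, total, rows);
        if (sum != total) {
            rows.add(new long[] {node.val, sum, sum * (total - sum)});
        }
        return sum;
    }

    static List<long[]> table(TreeNode root) {
        List<long[]> rows = new ArrayList<>();
        visit(root, getSum(root), rows);
        return rows;
    }

    public static void main(String[] args) {
        // The example tree: 1 -> (2 -> (4, 5), 3)
        TreeNode root = new TreeNode(1,
                new TreeNode(2, new TreeNode(4), new TreeNode(5)),
                new TreeNode(3));
        for (long[] row : table(root)) {
            System.out.println("node " + row[0] + ": sum = " + row[1] + ", product = " + row[2]);
        }
    }
}
```

Because the traversal is post-order, the rows appear in the order 4, 5, 2, 3 rather than the order listed above, with products 44, 50, 44 and 36.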