This is a DP problem on trees and I tried to solve it using the following approach.
Store a DP state DP[current][mode] = min_cost for the subtree rooted at current,
where mode is 0 if the parent of current has not bought a family ticket and 1 if it has.
The recurrence (both sums run over all children of current) is:
if (mode == 0)
    DP[current][0] = min(cost_of_single_ticket + sum of DP[child][0], cost_of_family_ticket + sum of DP[child][1])
else
    DP[current][1] = min(sum of DP[child][0], cost_of_family_ticket + sum of DP[child][1])
For a leaf both sums are empty, so DP[leaf][0] = min(cost_of_single_ticket, cost_of_family_ticket), and DP[leaf][1] = 0 (assuming a non-negative family ticket price).
The algorithm seems right, but I am getting Wrong Answer. My code is below; it is long and somewhat confusing. If you have already solved this problem, could you provide some tricky test cases?
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.StringTokenizer;
/**
 * Tree DP that picks the cheapest mix of single and family tickets for a set of family trees.
 *
 * <p>Input (stdin): a header line {@code single_cost family_cost}, followed by lines of the form
 * {@code parent child1 child2 ...}. A later line that starts with a number is the header of the
 * next test case; the header {@code 0 0} ends the input. For every test case one line
 * {@code "<case>. <singles> <families> <cost>"} is printed.
 *
 * <p>As modelled by {@link #solver}, a family ticket bought by a person covers that person and
 * all of their children, while a single ticket covers only its buyer.
 */
public class Main
{
    private static MyScanner sc;
    private static PrintWriter out;
    /** Price of a single ticket in the current test case. */
    private static int single_cost;
    /** Price of a family ticket in the current test case. */
    private static int family_cost;
    /** Person name -> dense id used to index {@link #Tree} and {@link #DP}. */
    private static HashMap<String, Integer> map;
    /** Adjacency lists: Tree[id] holds the children of the person with that id. */
    private static ArrayList<String>[] Tree;
    /** For every person listed as a parent: 1 + number of times it appears as someone's child. */
    private static HashMap<String, Integer> times;
    /** Memo table DP[id][mode] for non-leaf people (see {@link #solver}). */
    private static State[][] DP;

    /**
     * Reads every test case from standard input and prints one result line per case.
     *
     * @param args ignored
     */
    public static void main(String[] args)
    {
        sc = new MyScanner();
        out = new PrintWriter(System.out);
        single_cost = sc.nextInt();
        family_cost = sc.nextInt();
        map = new HashMap<>();
        times = new HashMap<>();
        HashMap<String, ArrayList<String>> data = new HashMap<>();
        int counter = 1;
        int unique_id = 0;
        // Nothing to solve when the very first header is already the terminator.
        boolean done = single_cost == 0 && family_cost == 0;
        while (!done)
        {
            String raw = sc.nextLine();
            if (raw == null)
            {
                // EOF without a "0 0" terminator: finish the pending case instead of crashing
                // with a NullPointerException (which also lost all buffered output).
                if (!data.isEmpty()) out.println(solveCase(counter++, data));
                break;
            }
            // Trim first: leading blanks made split() return an empty first token (read as a
            // person named ""), and a whitespace-only line produced no tokens at all.
            String k = raw.trim();
            if (k.isEmpty()) continue;
            String[] line = k.split("\\s+");
            if (isNumber(line[0]))
            {
                // A numeric line is the next header, so the current case is complete.
                int nu_s = Integer.parseInt(line[0]);
                int nu_f = Integer.parseInt(line[1]);
                out.println(solveCase(counter++, data));
                if (nu_s == 0 && nu_f == 0)
                {
                    done = true;
                }
                else
                {
                    data = new HashMap<>();
                    map = new HashMap<>();
                    unique_id = 0;
                    single_cost = nu_s;
                    family_cost = nu_f;
                }
            }
            else
            {
                // "parent child1 child2 ...": give every new name an id and record the children.
                if (!map.containsKey(line[0])) map.put(line[0], unique_id++);
                if (!data.containsKey(line[0])) data.put(line[0], new ArrayList<String>());
                ArrayList<String> children = data.get(line[0]);
                for (int i = 1; i < line.length; i++)
                {
                    if (!map.containsKey(line[i])) map.put(line[i], unique_id++);
                    children.add(line[i]);
                }
            }
        }
        out.close();
    }

    /**
     * Solves one test case and formats its result line.
     *
     * @param index 1-based case number printed in front of the result
     * @param data parent name -> children names read for this case
     * @return {@code "<index>. <singles> <families> <cost>"}
     */
    private static String solveCase(int index, HashMap<String, ArrayList<String>> data)
    {
        buildTree(data);
        int s = 0;
        int f = 0;
        int c = 0;
        for (String root : get_root())
        {
            State t = solver(root, 0);
            s += t.single_qty;
            f += t.family_qty;
            c += t.cost();
        }
        return index + ". " + s + " " + f + " " + c;
    }

    /**
     * Rebuilds the adjacency lists, the DP table and the parent-occurrence counts for the
     * current case from the families that were read.
     *
     * @param data parent name -> children names read for this case
     */
    @SuppressWarnings("unchecked") // generic array creation for the adjacency lists
    private static void buildTree(HashMap<String, ArrayList<String>> data)
    {
        int n = map.size();
        DP = new State[n][2];
        Tree = new ArrayList[n];
        for (int i = 0; i < n; i++) Tree[i] = new ArrayList<>();
        times = new HashMap<>();
        for (String parent : data.keySet())
        {
            Tree[map.get(parent)].addAll(data.get(parent));
            times.put(parent, 1);
        }
        // Every appearance as someone's child bumps the count, so a count of 1 marks a root.
        for (String parent : data.keySet())
        {
            for (String child : data.get(parent))
            {
                if (times.containsKey(child)) times.put(child, times.get(child) + 1);
            }
        }
    }

    /**
     * Returns the people listed as parents who never appear as anyone's child.
     *
     * @return roots of the family trees of the current case
     */
    private static ArrayList<String> get_root()
    {
        ArrayList<String> roots = new ArrayList<>();
        for (String node : times.keySet())
        {
            if (times.get(node) == 1) roots.add(node);
        }
        return roots;
    }

    /**
     * Returns whether {@code a} parses as an {@code int}.
     *
     * @param a token to test
     * @return true if {@link Integer#parseInt(String)} accepts it
     */
    private static boolean isNumber(String a)
    {
        try
        {
            Integer.parseInt(a);
            return true;
        }
        catch (NumberFormatException e)
        {
            return false;
        }
    }

    /**
     * Returns the cheapest ticket purchase for the subtree rooted at {@code current}.
     *
     * <p>The recursion depth equals the depth of the tree, so very deep family chains may
     * exhaust the default thread stack.
     *
     * @param current person whose subtree is solved
     * @param mode 0 if {@code current}'s parent did not buy a family ticket, so {@code current}
     *     must be covered inside this subtree; 1 if it did, so {@code current} is already covered
     * @return ticket counts of the cheapest option; on a cost tie an inner node prefers buying a
     *     family ticket, and a leaf prefers a single ticket
     */
    private static State solver(String current, int mode)
    {
        int id = map.get(current);
        if (Tree[id].isEmpty())
        {
            State leaf = new State();
            if (mode == 0)
            {
                // Same recurrence as an inner node with no children: min(single, family).
                // The family option used to be missing here, giving a wrong answer whenever a
                // family ticket is strictly cheaper than a single one. Ties keep the single ticket.
                if (family_cost < single_cost) leaf.family_qty = 1;
                else leaf.single_qty = 1;
            }
            // mode == 1: already covered by the parent's family ticket, nothing to buy.
            return leaf;
        }
        if (DP[id][mode] != null) return DP[id][mode];

        // Option 1: current buys no family ticket. It needs a single ticket unless the parent's
        // family ticket covers it, and its children are left uncovered (mode 0).
        State noFamily = new State();
        if (mode == 0) noFamily.single_qty = 1;
        for (String child : Tree[id]) noFamily.add(solver(child, 0));

        // Option 2: current buys a family ticket, which covers every child (mode 1).
        State family = new State();
        family.family_qty = 1;
        for (String child : Tree[id]) family.add(solver(child, 1));

        DP[id][mode] = noFamily.minimum(family);
        return DP[id][mode];
    }

    private static int max(int a, int b)
    {
        return a > b ? a : b;
    }

    private static int min(int a, int b)
    {
        return a < b ? a : b;
    }

    /** Number of single and family tickets bought; priced with the current case's costs. */
    private static class State
    {
        public int single_qty;
        public int family_qty;

        public State()
        {
            single_qty = 0;
            family_qty = 0;
        }

        /** Total price of the tickets, using the current {@code single_cost}/{@code family_cost}. */
        public int cost()
        {
            return (single_qty * single_cost) + (family_qty * family_cost);
        }

        /** Adds {@code other}'s ticket counts to this state. */
        public void add(State other)
        {
            single_qty += other.single_qty;
            family_qty += other.family_qty;
        }

        /** Returns the cheaper of this and {@code t}; returns {@code t} on a tie. */
        public State minimum(State t)
        {
            return this.cost() < t.cost() ? this : t;
        }

        /** Prints this state as a result line with the given case number. */
        public void data_out(int index)
        {
            out.println(index + ". " + this.single_qty + " " + this.family_qty + " " + this.cost());
        }
    }

    /** Buffered reader over standard input offering token and whole-line reads. */
    public static class MyScanner
    {
        BufferedReader br;
        StringTokenizer st;

        public MyScanner()
        {
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        /** Returns the next whitespace-separated token, reading further lines as needed. */
        String next()
        {
            while (st == null || !st.hasMoreElements())
            {
                try
                {
                    st = new StringTokenizer(br.readLine());
                }
                catch (IOException e)
                {
                    e.printStackTrace();
                }
            }
            return st.nextToken();
        }

        int nextInt()
        {
            return Integer.parseInt(next());
        }

        long nextLong()
        {
            return Long.parseLong(next());
        }

        double nextDouble()
        {
            return Double.parseDouble(next());
        }

        /**
         * Reads the next raw line directly from the reader.
         *
         * @return the line, or {@code null} at end of input
         */
        String nextLine()
        {
            String str = "";
            try
            {
                str = br.readLine();
            }
            catch (IOException e)
            {
                e.printStackTrace();
            }
            return str;
        }
    }
}