diff --git a/lib/mix/tasks/sql.gen.parser.ex b/lib/mix/tasks/sql.gen.parser.ex index 5cc76cd..d845d4e 100644 --- a/lib/mix/tasks/sql.gen.parser.ex +++ b/lib/mix/tasks/sql.gen.parser.ex @@ -156,6 +156,9 @@ defmodule Mix.Tasks.Sql.Gen.Parser do def insert_node({:with = tag, meta, []}, [{:ident, _, _} = l, {:parens, _, _} = r, {:as = t2, m2, a}], [], context, root) do {[], [], context, root ++ [{tag, meta, [{t2, m2, [[l, r] | a]}]}]} end + def insert_node({:with = tag, meta, []}, unit, acc, context, root) do + {[], [], context, root ++ [{tag, meta, unit ++ acc}]} + end def insert_node({tag, meta, []}, unit, acc, context, root) when tag in ~w[by in references]a do {[{tag, meta, predicate(unit ++ acc)}], [], context, root} end diff --git a/lib/parser.ex b/lib/parser.ex index 7ae34f5..9b854a3 100644 --- a/lib/parser.ex +++ b/lib/parser.ex @@ -126,6 +126,9 @@ defmodule SQL.Parser do def insert_node({:with = tag, meta, []}, [{:ident, _, _} = l, {:parens, _, _} = r, {:as = t2, m2, a}], [], context, root) do {[], [], context, root ++ [{tag, meta, [{t2, m2, [[l, r] | a]}]}]} end + def insert_node({:with = tag, meta, []}, unit, acc, context, root) do + {[], [], context, root ++ [{tag, meta, unit ++ acc}]} + end def insert_node({tag, meta, []}, unit, acc, context, root) when tag in ~w[by in references]a do {[{tag, meta, predicate(unit ++ acc)}], [], context, root} end diff --git a/test/sql_test.exs b/test/sql_test.exs index c4f95cc..67c5a95 100644 --- a/test/sql_test.exs +++ b/test/sql_test.exs @@ -142,6 +142,72 @@ defmodule SQLTest do test "regular" do assert "with temp (n, fact) as (select 0, 1 union all select n + 1, (n + 1) * fact from temp where n < 9)" == to_string(~SQL[with temp (n, fact) as (select 0, 1 union all select n+1, (n+1)*fact from temp where n < 9)]) end + + test "complex with" do + sql = ~SQL[ + with customer_rankings as( + select customer_id, + sum(amount) as total_spent, + rank() over(order by sum(amount) desc) as spending_rank + from transactions + group by customer_id + ), + top_customers as( + select c.customer_id, + c.name, + cr.total_spent, + cr.spending_rank + from customer_rankings cr + join customers c on c.customer_id = cr.customer_id + where cr.spending_rank <= 10 + ) + select tc.name, + tc.total_spent, + tc.spending_rank + from top_customers tc + order by tc.spending_rank + ] + + output = to_string(sql) + assert output == "with customer_rankings as(select customer_id, sum(amount) as total_spent, rank() over(order by sum(amount) desc) as spending_rank from transactions group by customer_id), top_customers as(select c.customer_id, c.name, cr.total_spent, cr.spending_rank from customer_rankings cr join customers c on c.customer_id = cr.customer_id where cr.spending_rank <= 10) select tc.name, tc.total_spent, tc.spending_rank from top_customers tc order by tc.spending_rank" + end + + test "complex with multiple ctes" do + sql = ~SQL[ + with customer_rankings as ( + select + customer_id, + sum(amount) as total_spent, + rank() over (order by sum(amount) desc) as spending_rank + from transactions + group by customer_id + ), + top_customers as ( + select + c.customer_id, + c.name, + cr.total_spent, + cr.spending_rank + from customer_rankings cr + join customers c on c.customer_id = cr.customer_id + where cr.spending_rank <= 10 + ) + select + tc.name, + tc.total_spent, + tc.spending_rank, + case + when tc.total_spent > tc.avg_amount * 2 then 'High Value' + when tc.total_spent > tc.avg_amount then 'Medium Value' + else 'Low Value' + end as customer_segment + from top_customers tc + order by tc.spending_rank, tc.month + ] + + output = to_string(sql) + assert output == "with customer_rankings as(select customer_id, sum(amount) as total_spent, rank() over(order by sum(amount) desc) as spending_rank from transactions group by customer_id), top_customers as(select c.customer_id, c.name, cr.total_spent, cr.spending_rank from customer_rankings cr join customers c on c.customer_id = cr.customer_id where cr.spending_rank <= 10) select tc.name, tc.total_spent, tc.spending_rank, case when tc.total_spent > tc.avg_amount * 2 then 'High Value' when tc.total_spent > tc.avg_amount then 'Medium Value' else 'Low Value' end as customer_segment from top_customers tc order by tc.spending_rank, tc.month" + end end describe "combinations" do