Pinx0 Pinx0 - 5 months ago 13
SQL Question

Cluster data into subsets

I have a table with, let's say, 100 values that I want to classify into four groups so that the dispersion inside each group is minimized.

I have done this successfully in Excel, using Solver to minimize the sum of the variances of the groups, but I want to do it in SQL and I'm stuck on how to proceed.

The output I want is, for each row that has an ID and a Value, a third column with the proposed classification (1, 2, 3 or 4).

Put graphically, it basically comes down to determining the orange lines:

[image]

EDIT: Sample set and desired output:

http://sqlfiddle.com/#!9/d3736/1

Answer

Okay, I love this question, although I agree that using SQL to solve this is never going to be efficient.

My solution takes a brute-force approach. Basically, I do this:

  • Sort the data into order, assigning an incremental number to each data point;
  • Group the data into four classes, assuming that the lowest number will always be in class #1 and the highest number will always be in class #4, and produce a list of every combination of split points within those bounds;
  • Calculate the variance of each class, and the total variance, for each way of splitting the data up;
  • Finally, list the results for each combination in total variance order, so the top answer tells me how to achieve the lowest variance.
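
The objective behind these steps can be written out as a short sketch. The notation here ($x_i$ for the sorted values, $t_1, t_2, t_3$ for the thresholds, $C_k$ for the classes) is not part of the original answer; it only mirrors the `data_order < threshold` tests in the script. Note that what the script calls "variance" appears, from its `SUM(POWER(value - mean, 2))` expressions, to be the sum of squared deviations from the class mean, without dividing by the class size:

```latex
\[
C_1=\{1,\dots,t_1-1\},\quad
C_2=\{t_1,\dots,t_2-1\},\quad
C_3=\{t_2,\dots,t_3-1\},\quad
C_4=\{t_3,\dots,n\}
\]
\[
\min_{1<t_1<t_2<t_3\le n}\;
\sum_{k=1}^{4}\sum_{i\in C_k}\bigl(x_i-\bar{x}_k\bigr)^2,
\qquad
\bar{x}_k=\frac{1}{|C_k|}\sum_{i\in C_k}x_i
\]
```

Because the values are sorted first, every class is a contiguous run of positions, which is why three thresholds are enough to describe a candidate split.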

This is a horrible solution, as I am literally throwing around millions of rows to get the answer. It takes around 2-3 minutes to execute for your sample data, and I imagine it wouldn't scale well, so if your real data is much larger than this it wouldn't be appropriate.
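
To put a rough number on "millions": reading the recursion in the script below as letting the three thresholds range over positions 2 to `@max_order - 1` (186 positions for the 188 sample rows), and the set-based step as cross joining every data row with every iteration, the sizes work out roughly as follows. This is arithmetic under that reading of the script, not a figure taken from a run:

```latex
\[
\binom{186}{3}=\frac{186\cdot 185\cdot 184}{6}=1{,}055{,}240\ \text{iterations},
\qquad
188\times 1{,}055{,}240=198{,}385{,}120\ \text{joined rows}
\]
```

The iteration count grows with the cube of the number of data points, which is the reason the approach is unlikely to scale.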

However, I bet this could be optimised to some degree.

Also, I get a different answer from yours. The total variance for your proposed class assignment is 19.67926, but my best answer has a total variance of 18.95997.
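
One possible optimisation, offered as a sketch rather than as part of the original answer: since each class is a contiguous run of the sorted data, its sum of squared deviations can be derived from running totals as Σx² − (Σx)²/n, so no candidate split needs to re-scan the data. The `#demo` table and its eight values are hypothetical, the window frame syntax assumes SQL Server 2012 or later, and `FLOAT` is used to sidestep `NUMERIC` scale truncation (at the cost of ordinary floating-point rounding):

```sql
--Hypothetical sample data, already numbered in sorted order
CREATE TABLE #demo (data_order INT PRIMARY KEY, data_value FLOAT);
INSERT INTO #demo (data_order, data_value)
VALUES (1, 1.0), (2, 1.1), (3, 1.2), (4, 5.0), (5, 5.1), (6, 9.0), (7, 9.2), (8, 13.0);

WITH Prefix AS (
    --Running totals of x and x^2 up to and including each position
    SELECT
        data_order,
        SUM(data_value) OVER (ORDER BY data_order ROWS UNBOUNDED PRECEDING) AS s1,
        SUM(data_value * data_value) OVER (ORDER BY data_order ROWS UNBOUNDED PRECEDING) AS s2
    FROM
        #demo),
EndRow AS (
    --Totals over the whole data set
    SELECT TOP 1 data_order, s1, s2 FROM Prefix ORDER BY data_order DESC)
SELECT TOP 1
    a.data_order + 1 AS threshold_1, --first position of class 2
    b.data_order + 1 AS threshold_2, --first position of class 3
    c.data_order + 1 AS threshold_3, --first position of class 4
    --Each term is sum(x^2) - sum(x)^2 / n for one contiguous class
    (a.s2 - a.s1 * a.s1 / a.data_order)
    + ((b.s2 - a.s2) - (b.s1 - a.s1) * (b.s1 - a.s1) / (b.data_order - a.data_order))
    + ((c.s2 - b.s2) - (c.s1 - b.s1) * (c.s1 - b.s1) / (c.data_order - b.data_order))
    + ((z.s2 - c.s2) - (z.s1 - c.s1) * (z.s1 - c.s1) / (z.data_order - c.data_order)) AS total_variance
FROM
    Prefix a --a, b, c are the last positions of classes 1, 2 and 3
    INNER JOIN Prefix b ON b.data_order > a.data_order
    INNER JOIN Prefix c ON c.data_order > b.data_order
    INNER JOIN EndRow z ON z.data_order > c.data_order
ORDER BY
    total_variance;

DROP TABLE #demo;
```

For this toy data the three large gaps sit after positions 3, 5 and 7, so the query should come back with thresholds 4, 6 and 8. The number of candidate splits is still cubic in the number of rows; the saving is only that each candidate costs a handful of arithmetic operations instead of a pass over the data.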

Your answer is:

Class   Lower_Bound   Upper_Bound
1       3.638         5.223
2       5.321         6.683
3       6.951         8.241
4       8.561         10.100

My answer is:

Class   Lower_Bound   Upper_Bound
1       3.638         4.952
2       5.223         6.547
3       6.604         8.241
4       8.561         10.100

So the two are similar, with some slight movement in the boundaries between classes 1, 2 and 3.

Now for the horrid script. Apologies for this: it isn't optimised at all, and I left all my workings in.

--Grab the source data
DECLARE @Data TABLE (ID INT, Value NUMERIC(19,6), Class INT);
INSERT INTO @Data
VALUES
    (1000, 4.701, 1),
    (1001, 5.223, 1),
    (1002, 4.335, 1),
    (1003, 9.234, 4),
    (1004, 8.684, 4),
    (1005, 6.507, 2),
    (1006, 9.458, 4),
    (1007, 9.663, 4),
    (1011, 4.259, 1),
    (1012, 8.241, 3),
    (1013, 5.531, 2),
    (1014, 4.434, 1),
    (1015, 4.428, 1),
    (1016, 8.119, 3),
    (1017, 5.696, 2),
    (1018, 8.142, 3),
    (1019, 4.349, 1),
    (1020, 4.315, 1),
    (1021, 9.130, 4),
    (1023, 9.278, 4),
    (1024, 4.251, 1),
    (1027, 7.414, 3),
    (1028, 9.502, 4),
    (1032, 8.561, 4),
    (1033, 9.020, 4),
    (1034, 5.365, 2),
    (1037, 9.343, 4),
    (2000, 9.330, 4),
    (2001, 7.838, 3),
    (2002, 9.806, 4),
    (2003, 7.405, 3),
    (2004, 9.970, 4),
    (2008, 9.702, 4),
    (2009, 10.100, 4),
    (2010, 7.679, 3),
    (2011, 7.180, 3),
    (2012, 8.936, 4),
    (3000, 7.249, 3),
    (3001, 6.547, 2),
    (3002, 5.608, 2),
    (3003, 5.613, 2),
    (3004, 4.473, 1),
    (3005, 5.430, 2),
    (3007, 5.766, 2),
    (3009, 4.466, 1),
    (3011, 4.532, 1),
    (3012, 4.878, 1),
    (3013, 6.388, 2),
    (3014, 4.413, 1),
    (3015, 4.689, 1),
    (3016, 6.683, 2),
    (3017, 5.708, 2),
    (3018, 5.468, 2),
    (3020, 9.797, 4),
    (3022, 6.018, 2),
    (3027, 4.493, 1),
    (3031, 4.381, 1),
    (4001, 4.720, 1),
    (4002, 4.482, 1),
    (4003, 5.631, 2),
    (4004, 8.859, 4),
    (4005, 4.788, 1),
    (4006, 8.573, 4),
    (4007, 5.553, 2),
    (4008, 6.604, 2),
    (4009, 4.394, 1),
    (4010, 6.313, 2),
    (5000, 4.269, 1),
    (5001, 4.162, 1),
    (5002, 4.614, 1),
    (5003, 4.142, 1),
    (5004, 3.975, 1),
    (5005, 4.076, 1),
    (5007, 4.299, 1),
    (5008, 4.219, 1),
    (5009, 4.229, 1),
    (5010, 4.109, 1),
    (5011, 4.086, 1),
    (5012, 4.617, 1),
    (5013, 5.470, 2),
    (5014, 4.366, 1),
    (5015, 4.655, 1),
    (5017, 4.083, 1),
    (5018, 4.261, 1),
    (5019, 4.104, 1),
    (5020, 4.297, 1),
    (5021, 4.426, 1),
    (5022, 6.189, 2),
    (5023, 4.327, 1),
    (5024, 4.380, 1),
    (6000, 4.216, 1),
    (6001, 7.150, 3),
    (6002, 7.321, 3),
    (6003, 4.198, 1),
    (6004, 4.111, 1),
    (6005, 5.321, 2),
    (6006, 3.891, 1),
    (6007, 7.370, 3),
    (6008, 7.417, 3),
    (6009, 7.095, 3),
    (6010, 7.115, 3),
    (6011, 6.005, 2),
    (6012, 4.152, 1),
    (6013, 5.683, 2),
    (6014, 4.952, 1),
    (6015, 3.881, 1),
    (6016, 5.412, 2),
    (6017, 5.405, 2),
    (6018, 7.163, 3),
    (6019, 4.451, 1),
    (6020, 4.150, 1),
    (6021, 4.424, 1),
    (6022, 7.156, 3),
    (6024, 6.242, 2),
    (6025, 4.488, 1),
    (6026, 5.732, 2),
    (6027, 4.390, 1),
    (6028, 5.580, 2),
    (6029, 6.265, 2),
    (6032, 5.493, 2),
    (6033, 4.281, 1),
    (6034, 4.387, 1),
    (7000, 4.300, 1),
    (7001, 4.349, 1),
    (7002, 4.241, 1),
    (7003, 4.213, 1),
    (7004, 4.363, 1),
    (7005, 4.217, 1),
    (7006, 4.213, 1),
    (7008, 4.484, 1),
    (7009, 4.086, 1),
    (7010, 4.072, 1),
    (7011, 4.067, 1),
    (7012, 4.098, 1),
    (7013, 5.838, 2),
    (7015, 4.028, 1),
    (7016, 3.880, 1),
    (7021, 3.797, 1),
    (7022, 3.990, 1),
    (7023, 4.263, 1),
    (7024, 3.968, 1),
    (7026, 3.926, 1),
    (7030, 4.326, 1),
    (7031, 4.158, 1),
    (7032, 4.387, 1),
    (7033, 4.836, 1),
    (7034, 4.282, 1),
    (7035, 4.418, 1),
    (7036, 4.352, 1),
    (7037, 4.267, 1),
    (7038, 4.394, 1),
    (7039, 4.195, 1),
    (7040, 4.367, 1),
    (7042, 4.339, 1),
    (7043, 4.024, 1),
    (7044, 4.398, 1),
    (7045, 4.339, 1),
    (7046, 4.283, 1),
    (7047, 4.422, 1),
    (8000, 4.175, 1),
    (8001, 4.178, 1),
    (8002, 4.256, 1),
    (8003, 6.951, 3),
    (8004, 4.329, 1),
    (8007, 7.603, 3),
    (8008, 6.457, 2),
    (8011, 7.551, 3),
    (8012, 4.361, 1),
    (8014, 7.009, 3),
    (8015, 4.293, 1),
    (8016, 4.131, 1),
    (8017, 4.000, 1),
    (8019, 3.915, 1),
    (8022, 3.731, 1),
    (8023, 4.192, 1),
    (8024, 4.221, 1),
    (8025, 4.212, 1),
    (8028, 4.056, 1),
    (9001, 4.429, 1),
    (9002, 4.432, 1),
    (9003, 4.445, 1),
    (9004, 3.696, 1),
    (9005, 4.269, 1),
    (9010, 4.434, 1),
    (9011, 3.677, 1),
    (9016, 4.440, 1),
    (9017, 3.638, 1),
    (9018, 4.426, 1);

--Prove we can calculate the total variance from this
WITH Metrics AS (
    SELECT
        Class,
        SUM(Value) AS Class_Total,
        COUNT(*) AS Class_Items,
        SUM(Value) / COUNT(*) AS Mean_Value
    FROM
        @Data
    GROUP BY
        Class),
Variance AS (
    SELECT
        d.Class,
        m.Class_Items,
        MIN(d.Value) AS Lower_Bound,
        MAX(d.Value) AS Upper_Bound,
        SUM(POWER(d.Value - m.Mean_Value, 2)) AS Variance
    FROM
        @Data d
        INNER JOIN Metrics m ON m.Class = d.Class
    GROUP BY
        d.Class,
        m.Class_Items)
SELECT * FROM Variance;

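--Brute force the classes
--Only drop the temp tables if they are left over from a previous run
IF OBJECT_ID('tempdb..#class') IS NOT NULL DROP TABLE #class;
IF OBJECT_ID('tempdb..#data') IS NOT NULL DROP TABLE #data;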
CREATE TABLE #class (
    iteration INT,
    threshold_1 INT,
    threshold_2 INT,
    threshold_3 INT);

--Organise the population into order
CREATE TABLE #data (
    data_order INT,
    data_value NUMERIC(19,6));
INSERT INTO
    #data (data_order, data_value)
SELECT
    ROW_NUMBER() OVER (ORDER BY Value),
    Value
FROM
    @Data;

DECLARE @max_order INT;
SELECT @max_order = MAX(data_order) FROM #data;
SELECT @max_order;

--Set up the initial iteration
INSERT INTO 
    #class
SELECT
    1,
    2,
    3,
    4;

--Now use recursion to set up every other iteration to test
IF OBJECT_ID('tempdb..#iterations') IS NOT NULL DROP TABLE #iterations;
WITH recursion AS (
    SELECT 
        iteration,
        threshold_1,
        threshold_2,
        threshold_3
    FROM 
        #class
    UNION ALL
    SELECT
        iteration + 1,
        CASE
            WHEN threshold_3 < @max_order - 1 OR threshold_2 < @max_order - 2 THEN threshold_1
            ELSE threshold_1 + 1
        END,
        CASE
            WHEN threshold_3 < @max_order - 1 THEN threshold_2
            WHEN threshold_2 < @max_order - 2 THEN threshold_2 + 1
            ELSE threshold_1 + 2
        END,
        CASE
            WHEN threshold_3 < @max_order - 1 THEN threshold_3 + 1
            WHEN threshold_2 < @max_order - 2 THEN threshold_2 + 2
            ELSE threshold_1 + 3
        END
    FROM
        recursion
    WHERE
        threshold_1 < @max_order - 3)
SELECT 
    *,
    CONVERT(NUMERIC(19,6), NULL) AS total_variance
INTO
    #iterations
FROM 
    recursion 
OPTION (MAXRECURSION 0);

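--Now work over the set of iterations, calculating the total variance for each
--(row-by-row working; the set-based statement further down fills in the same column in one pass, so only one of the two needs to run)
DECLARE @iteration INT;
SELECT @iteration = MIN(iteration) FROM #iterations WHERE total_variance IS NULL;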

WHILE @iteration IS NOT NULL
BEGIN
    WITH ClassAssignment AS (
        SELECT
            d.*,
            CASE
                WHEN data_order < threshold_1 THEN 1
                WHEN data_order < threshold_2 THEN 2
                WHEN data_order < threshold_3 THEN 3
                ELSE 4
            END AS class
        FROM
            #data d
            CROSS JOIN #iterations i
        WHERE
            i.iteration = @iteration),
    Metrics AS (
        SELECT
            class,
            --SUM(data_value) AS class_total,
            --COUNT(*) AS class_items,
            SUM(data_value) / COUNT(*) AS mean_value
        FROM
            ClassAssignment
        GROUP BY
            class),
    Variance AS (
        SELECT
            d.class,
            SUM(POWER(d.data_value - m.mean_value, 2)) AS variance
        FROM
            ClassAssignment d
            INNER JOIN Metrics m ON m.class = d.class
        GROUP BY
            d.class),
    TotalVariance AS (
        SELECT SUM(Variance) AS total_variance FROM Variance)
    UPDATE
        i
    SET
        total_variance = v.total_variance
    FROM
        #iterations i
        CROSS JOIN TotalVariance v
    WHERE
        iteration = @iteration;
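    PRINT 'Iteration #' + CONVERT(VARCHAR(50), @iteration);
    --Move on to the next untested iteration; MIN() returns NULL once none are left, which ends the loop
    SELECT @iteration = MIN(iteration) FROM #iterations WHERE total_variance IS NULL;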
END;

--Try to do this in a set-based way
WITH ClassAssignment AS (
    SELECT
        i.iteration,
        d.*,
        CASE
            WHEN data_order < threshold_1 THEN 1
            WHEN data_order < threshold_2 THEN 2
            WHEN data_order < threshold_3 THEN 3
            ELSE 4
        END AS class
    FROM
        #data d
        CROSS JOIN #iterations i),
Metrics AS (
    SELECT
        iteration,
        class,
        SUM(data_value) / COUNT(*) AS mean_value
    FROM
        ClassAssignment
    GROUP BY
        class,
        iteration),
Variance AS (
    SELECT
        d.iteration,
        d.class,
        SUM(POWER(d.data_value - m.mean_value, 2)) AS variance
    FROM
        ClassAssignment d
        INNER JOIN Metrics m ON m.iteration = d.iteration AND m.class = d.class
    GROUP BY
        d.iteration,
        d.class),
TotalVariance AS (
    SELECT
        iteration,
        SUM(Variance) AS total_variance 
    FROM 
        Variance
    GROUP BY
        iteration)
UPDATE
    i
SET
    total_variance = v.total_variance
FROM
    #iterations i
    INNER JOIN TotalVariance v ON v.iteration = i.iteration;

--Results
SELECT TOP 1 * FROM #iterations WHERE total_variance IS NOT NULL ORDER BY total_variance;

--Best result is 116, 147, 170, with a total variance of 18.959971
--What is this in terms of numbers?
--Class #1
SELECT * FROM #data WHERE data_order = 1;
SELECT * FROM #data WHERE data_order = 115;
--Class #2
SELECT * FROM #data WHERE data_order = 116;
SELECT * FROM #data WHERE data_order = 146;
--Class #3
SELECT * FROM #data WHERE data_order = 147;
SELECT * FROM #data WHERE data_order = 169;
--Class #4
SELECT * FROM #data WHERE data_order = 170;
SELECT * FROM #data WHERE data_order = 188;
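
The question asked for a third column with the proposed class next to each ID and Value, which the script stops short of producing. A minimal sketch of that last step follows. It is meant to be appended to the script above and run in the same batch, since it relies on `@Data` and `#iterations` still being in scope. The CTE names `Best` and `Ranked` and the column name `Proposed_Class` are made up for illustration. It also assumes that `ROW_NUMBER()` numbers the rows the same way it did when `#data` was filled, which is only guaranteed when a threshold does not fall between two equal values:

```sql
--Attach the winning class to every original row
WITH Best AS (
    --The split with the lowest total variance
    SELECT TOP 1
        threshold_1,
        threshold_2,
        threshold_3
    FROM
        #iterations
    WHERE
        total_variance IS NOT NULL
    ORDER BY
        total_variance),
Ranked AS (
    --Same numbering as #data, but keeping the ID
    SELECT
        ID,
        Value,
        ROW_NUMBER() OVER (ORDER BY Value) AS data_order
    FROM
        @Data)
SELECT
    r.ID,
    r.Value,
    CASE
        WHEN r.data_order < b.threshold_1 THEN 1
        WHEN r.data_order < b.threshold_2 THEN 2
        WHEN r.data_order < b.threshold_3 THEN 3
        ELSE 4
    END AS Proposed_Class
FROM
    Ranked r
    CROSS JOIN Best b
ORDER BY
    r.ID;
```

The `CASE` expression is the same one the script uses inside `ClassAssignment`, applied once with the best thresholds instead of once per iteration.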