Dancing Linksを用いたKnuth's Algorithm Xによる「四角に切れ」ソルバーの実装

四角に切れ」とは以下のような問題である。

  • あるサイズの格子(将棋盤タイプ)が与えられる。
  • マスには数字が書かれているか、書かれていない。
  • 格子全体を長方形の集合に分割する。
  • 全ての長方形にちょうど1つずつ数字が入り、かつ長方形の面積がその数字と等しい。

この問題は以下のようにして、Exact Cover Problemに帰着できる。

  • マス全体の集合を「全体集合」とおく。
  • ある数字を含み、面積がその数字と等しく、他の数字を含まないような長方形をピースとし、ピース全体の集合を「部分集合の集合」とおく。

これを、Knuth's Algorithm XとDancing Linksの解説 - TopCoderとJ言語と時々F#で解説されているKnuth's Algorithm X で解いた。

入力形式

ここではパソコン甲子園版の入力ではなく、以下のような入力形式とした。

3 2
2 0 0
0 4 0

プログラム

#include <cstdio>
#include <cstdlib>
#include <vector>
#include <climits>
using namespace std;

int w,h;
vector<vector<int> > table;
#define xyc(x,y) ((x)*h+(y))

vector<vector<int> > pieces;

struct dlx_node;
struct dlx_node {
    int xy, pieceid;
    dlx_node *l,*r,*u,*d;
    dlx_node() {
        reset();
    }
    void reset() {
        l = r = u = d = this;
        xy = -1;
        pieceid = -1;
    }
};

dlx_node head;
vector<dlx_node> piece_nodes;
vector<int> xy_count;
vector<dlx_node> xy_nodes;
vector<dlx_node> nodes;
vector<int> result;

void search() {
    if(head.l == &head && head.d == &head) {
        vector<int> keta(w, 1);
        for(int y = 0; y < h; y++) {
            for(int x = 0; x < w; x++) {
                if(pieces[result[xyc(x,y)]].size()>=10)keta[x]=2;
            }
        }
        for(int y = 0; y < h; y++) {
            for(int x = 0; x < w; x++) {
                printf(keta[x]==2 ? "%2d " : "%1d ", pieces[result[xyc(x,y)]].size());
            }
            printf("\n");
        }
        printf("\n\n");
        return;
    }
    if(head.l == &head)return;
    if(head.d == &head)return;
    dlx_node *min_xy_node = NULL;
    int min_xy_count = INT_MAX;
    for(dlx_node *i = head.r; i != &head; i=i->r) {
        if(xy_count[i->xy] < min_xy_count) {
            min_xy_count = xy_count[i->xy];
            min_xy_node = i;
        }
    }
    if(min_xy_count==0) return;
    for(dlx_node *i = min_xy_node->d; i != min_xy_node; i=i->d) {
        vector<int>& v = pieces[i->pieceid];
        for(int j = 0; j < (int)v.size(); j++) {
            result[v[j]] = i->pieceid;
        }

        dlx_node *ih = &piece_nodes[i->pieceid];
        vector<int> kh_stk;
        for(dlx_node *j = ih->r; j != ih; j=j->r) {
            dlx_node *jh = &xy_nodes[j->xy];
            for(dlx_node *k = jh->d; k != jh; k=k->d) {
                kh_stk.push_back(k->pieceid);
                dlx_node *kh = &piece_nodes[k->pieceid];
                for(dlx_node *l = kh->r; l != kh; l=l->r) {
                    l->u->d = l->d;
                    l->d->u = l->u;
                    xy_count[l->xy]--;
                }
                kh->u->d = kh->d;
                kh->d->u = kh->u;
            }
        }
        vector<int> jh_stk;
        for(dlx_node *j = ih->r; j != ih; j=j->r) {
            jh_stk.push_back(j->xy);
            dlx_node *jh = &xy_nodes[j->xy];
            jh->r->l = jh->l;
            jh->l->r = jh->r;
        }
        search();
        while(!jh_stk.empty()) {
            dlx_node *jh = &xy_nodes[jh_stk.back()];
            jh->r->l = jh;
            jh->l->r = jh;
            jh_stk.pop_back();
        }
        while(!kh_stk.empty()) {
            dlx_node *kh = &piece_nodes[kh_stk.back()];
                for(dlx_node *l = kh->r; l != kh; l=l->r) {
                    l->u->d = l;
                    l->d->u = l;
                    xy_count[l->xy]++;
                }
                kh->u->d = kh;
                kh->d->u = kh;
            kh_stk.pop_back();
        }
    }
}

int main(int argc, char **argv)
{
    scanf("%d%d", &w, &h);
    table = vector<vector<int> >(w, vector<int>(h));
    for(int y = 0; y < h; y++) {
        for(int x = 0; x < w; x++) {
            scanf("%d", &table[x][y]);
        }
    }
    {
        printf("input data:\n");
        vector<int> keta(w, 1);
        for(int y = 0; y < h; y++) {
            for(int x = 0; x < w; x++) {
                if(table[x][y]>=10)keta[x]=2;
            }
        }
        for(int y = 0; y < h; y++) {
            for(int x = 0; x < w; x++) {
                printf(keta[x]==2 ? "%2d " : "%1d ", table[x][y]);
            }
            printf("\n");
        }
        printf("\n\n");
    }
    int nodesize = 0;
    int piecesize_sum = 0;
    for(int cx = 0; cx < w; cx++) {
        for(int cy = 0; cy < h; cy++) {
            int piece_size = table[cx][cy];
            piecesize_sum += piece_size;
            if(piece_size == 0)continue;
            for(int pw = 1; pw <= piece_size; pw++) {
                int ph = piece_size / pw;
                if(piece_size > ph*pw)continue;
                for(int sx = max(cx+1-pw,0); sx <= min(cx,w-pw); sx++) {
                    for(int sy = max(cy+1-ph,0); sy <= min(cy,h-ph); sy++) {
                        int sum = 0;
                        for(int x = sx; x < sx+pw; x++) {
                            for(int y = sy; y < sy+ph; y++) {
                                sum += table[x][y];
                            }
                        }
                        if(sum!=piece_size)continue;
                        pieces.push_back(vector<int>());
                        for(int x = sx; x < sx+pw; x++) {
                            for(int y = sy; y < sy+ph; y++) {
                                pieces.back().push_back(xyc(x,y));
                                nodesize++;
                            }
                        }
                    }
                }
            }
        }
    }
    if(piecesize_sum != w*h) {
        fprintf(stderr, "error: piece size sum(%d) != table size(%dx%d)\n", piecesize_sum, w, h);
        exit(1);
    }
    piece_nodes.resize(pieces.size());
    xy_nodes.resize(w*h);
    xy_count.resize(w*h);
    nodes.resize(nodesize);
    result.resize(w*h);
    for(int i = 0; i < (int)pieces.size(); i++) {
        piece_nodes[i].reset();
        piece_nodes[i].pieceid = i;
    }
    for(int i = 0; i < w*h; i++) {
        xy_nodes[i].reset();
        xy_nodes[i].xy = i;
    }
    for(int i = 1; i < (int)pieces.size(); i++) {
        piece_nodes[i].u = &piece_nodes[i-1];
        piece_nodes[i-1].d = &piece_nodes[i];
    }
    piece_nodes[0].pieceid = 0;
    head.d = &piece_nodes[0];
    head.u = &piece_nodes[pieces.size()-1];
    piece_nodes[0].u = &head;
    piece_nodes[pieces.size()-1].d = &head;

    for(int i = 1; i < w*h; i++) {
        xy_nodes[i].l = &xy_nodes[i-1];
        xy_nodes[i-1].r = &xy_nodes[i];
    }
    head.r = &xy_nodes[0];
    head.l = &xy_nodes[w*h-1];
    xy_nodes[0].l = &head;
    xy_nodes[w*h-1].r = &head;

    int nodecnt = 0;
    for(int i = 0; i < (int)pieces.size(); i++) {
        for(int j = 0; j < (int)pieces[i].size(); j++) {
            int jj = pieces[i][j];
            dlx_node *node = &nodes[nodecnt++];
            xy_nodes[jj].u->d = node;
            node->u = xy_nodes[jj].u;
            xy_nodes[jj].u = node;
            node->d = &xy_nodes[jj];

            piece_nodes[i].l->r = node;
            node->l = piece_nodes[i].l;
            piece_nodes[i].l = node;
            node->r = &piece_nodes[i];

            node->xy = jj;
            node->pieceid = i;
            xy_count[jj]++;
        }
    }
    search();
    return 0;
}