Project

General

Profile

Download (7.35 KB) Statistics
| Branch: | Revision:

distorted-objectlib / src / main / java / org / distorted / objectlib / tablebases / PruningTable.java @ 5e30b196

1
///////////////////////////////////////////////////////////////////////////////////////////////////
2
// Copyright 2023 Leszek Koltunski                                                               //
3
//                                                                                               //
4
// This file is part of Magic Cube.                                                              //
5
//                                                                                               //
6
// Magic Cube is proprietary software licensed under an EULA which you should have received      //
7
// along with the code. If not, check https://distorted.org/magic/License-Magic-Cube.html        //
8
///////////////////////////////////////////////////////////////////////////////////////////////////
9

    
10
package org.distorted.objectlib.tablebases;
11

    
12
///////////////////////////////////////////////////////////////////////////////////////////////////
13

    
14
class PruningTable
15
{
16
  private static final int HEADER_SIZE = 5;
17
  private final byte[] mData;
18
  private final int mNumBits;
19
  private final int mBucketBytes;
20
  private final int mTotalSize;
21
  private final int mNumBuckets;
22
  private final int mLevel;
23

    
24
///////////////////////////////////////////////////////////////////////////////////////////////////
25

    
26
  private static String printByte(byte b)
27
    {
28
    return String.format("%8s", Integer.toBinaryString(b & 0xFF)).replace(' ', '0');
29
    }
30

    
31
///////////////////////////////////////////////////////////////////////////////////////////////////
32

    
33
  private int numBytesForIndices(int total)
34
    {
35
    int bytes = 0;
36

    
37
    while(total!=0)
38
      {
39
      total>>=8;
40
      bytes++;
41
      }
42

    
43
    return bytes;
44
    }
45

    
46
///////////////////////////////////////////////////////////////////////////////////////////////////
47

    
48
  private void writeHeader()
49
    {
50
    mData[0] = (byte)mLevel;
51
    mData[1] = (byte)mBucketBytes;
52
    mData[2] = (byte)mNumBits;
53
    mData[3] = (byte)(mNumBuckets/256);
54
    mData[4] = (byte)(mNumBuckets%256);
55
    }
56

    
57
///////////////////////////////////////////////////////////////////////////////////////////////////
58

    
59
  private int getBucketOffset(int bucket)
60
    {
61
    if( bucket>=mNumBuckets ) return mTotalSize;
62

    
63
    int ret = 0;
64
    int start = HEADER_SIZE + bucket*mBucketBytes;
65

    
66
    for(int i=0; i<mBucketBytes; i++)
67
      {
68
      ret<<=8;
69
      ret += ( ((int)(mData[start+i]))&0xff);
70
      }
71

    
72
    return HEADER_SIZE+mNumBuckets*mBucketBytes+ret;
73
    }
74

    
75
///////////////////////////////////////////////////////////////////////////////////////////////////
76

    
77
  private void writeBucketPointers(int[] bucket)
78
    {
79
    int start=HEADER_SIZE;
80

    
81
    for( int b : bucket )
82
      {
83
      for( int i=0; i<mBucketBytes; i++)
84
        {
85
        mData[start+mBucketBytes-1-i] = (byte)(b%256);
86
        b>>=8;
87
        }
88
      start += mBucketBytes;
89
      }
90
    }
91

    
92
///////////////////////////////////////////////////////////////////////////////////////////////////
93

    
94
  private int retrieveData(int index, int entryInBucket)
95
    {
96
    int bits = entryInBucket*mNumBits;
97
    int bytes = index + bits/8;
98
    int ret=0;
99

    
100
    switch(mNumBits)
101
      {
102
      case  4: ret = (bits%8)!=0 ? (mData[bytes]&0xf) : ((mData[bytes]>>>4)&0xf);
103
               break;
104
      case  8: ret = (mData[bytes]&0xff);
105
               break;
106
      case 12: ret = (bits%8)!=0 ? ((mData[bytes]&0xf)*256 + (mData[bytes+1]&0xff)) : ( (mData[bytes]&0xff)*16 + ((mData[bytes+1]>>>4)&0xf));
107
               break;
108
      case 16: ret = (mData[bytes]&0xff)*256 + (mData[bytes+1]&0xff);
109
               break;
110
      }
111

    
112
    return ret;
113
    }
114

    
115
///////////////////////////////////////////////////////////////////////////////////////////////////
116

    
117
  private void insertData(int index, int entryInBucket, int value)
118
    {
119
    int bits = entryInBucket*mNumBits;
120
    int bytes = index + bits/8;
121

    
122
    switch(mNumBits)
123
      {
124
      case  4: if( (bits%8) !=0 ) mData[bytes] = (byte)((mData[bytes]&(~0xf)) + value  );
125
               else               mData[bytes] = (byte)((value<<4) + (mData[bytes]&0xf));
126
               break;
127
      case  8: mData[bytes] = (byte)value;
128
               break;
129
      case 12: if( (bits%8) !=0 )
130
                 {
131
                 mData[bytes]  = (byte)((mData[bytes]&(~0xf)) + (value>>>8));
132
                 mData[bytes+1]= (byte)value;
133
                 }
134
               else
135
                 {
136
                 mData[bytes]  = (byte)(value>>>4);
137
                 mData[bytes+1]= (byte)((byte)(value<<4) + (mData[bytes+1]&( 0xf)));
138
                 }
139
               break;
140
      case 16: mData[bytes]   = (byte)(value>>>8);
141
               mData[bytes+1] = (byte)(value&(0xff));
142
      }
143
    }
144

    
145
///////////////////////////////////////////////////////////////////////////////////////////////////
146

    
147
  private void writeBuckets(Tablebase table)
148
    {
149
    int mask = (1<<mNumBits);
150
    int prevBucket = -1;
151
    int entryInBucket = 0;
152
    int index = 0;
153
    int size = table.getSize();
154

    
155
    for(int i=0; i<size; i++)
156
      if( table.retrieveUnpacked(i)==mLevel )
157
        {
158
        int toBeWritten = (i%mask);
159
        int currBucket = (i>>mNumBits);
160

    
161
        if( currBucket!=prevBucket )
162
          {
163
          index = getBucketOffset(currBucket);
164
          entryInBucket = 0;
165
          }
166
        else
167
          {
168
          entryInBucket ++;
169
          }
170

    
171
        prevBucket = currBucket;
172
        insertData(index,entryInBucket,toBeWritten);
173
        }
174
    }
175

    
176
///////////////////////////////////////////////////////////////////////////////////////////////////
177
// PUBLIC API
178
///////////////////////////////////////////////////////////////////////////////////////////////////
179
// only numBits = 4,8,12,16 actually supported
180

    
181
  public PruningTable(Tablebase table, int level, int numBits)
182
    {
183
    mNumBits = numBits;
184
    mLevel = level;
185
    int size = table.getSize();
186
    int entrySize = (1<<numBits);
187
    mNumBuckets = (size+entrySize-1)/entrySize;
188
    int[] bucket = new int[mNumBuckets];
189

    
190
    for(int i=0; i<size; i++)
191
      if( table.retrieveUnpacked(i)==level )
192
        {
193
        int currBucket = i/entrySize;
194
        bucket[currBucket]++;
195
        }
196

    
197
    int totalBytesForTable = 0;
198

    
199
    for(int i=0; i<mNumBuckets; i++)
200
      {
201
      int numBytes = (bucket[i]*numBits+7)/8;
202
      bucket[i] = totalBytesForTable;
203
      totalBytesForTable+=numBytes;
204
      }
205

    
206
    mBucketBytes = numBytesForIndices(totalBytesForTable);
207
    int totalBytesForBuckets = mNumBuckets*mBucketBytes;
208
    mTotalSize = HEADER_SIZE + totalBytesForBuckets + totalBytesForTable;
209
    mData = new byte[mTotalSize];
210

    
211
    writeHeader();
212
    writeBucketPointers(bucket);
213
    writeBuckets(table);
214
    }
215

    
216
///////////////////////////////////////////////////////////////////////////////////////////////////
217

    
218
  public boolean belongs(int number)
219
    {
220
    int bucket = (number>>mNumBits);
221
    int offset1 = getBucketOffset(bucket);
222
    int offset2 = getBucketOffset(bucket+1);
223
    int numBytesInBucket = offset2-offset1;
224
    int numEntriesInBucket = (8*numBytesInBucket)/mNumBits;
225

    
226
    if( mNumBits==4 && numBytesInBucket>0 )
227
      {
228
      byte val = mData[offset2-1];
229
      if( (val/16)*16 == val ) numEntriesInBucket--;
230
      }
231

    
232
    int lower = 0;
233
    int upper = numEntriesInBucket-1;
234

    
235
    while( upper>=lower )
236
      {
237
      int mid = (lower+upper)/2;
238
      int value = retrieveData(offset1,mid);
239

    
240
      value += (bucket<<mNumBits);
241
      if( value==number ) return true;
242

    
243
      if( value>number ) upper = mid-1;
244
      else               lower = mid+1;
245
      }
246

    
247
    return false;
248
    }
249
}
(2-2/11)