Project

General

Profile

Download (11.4 KB) Statistics
| Branch: | Revision:

distorted-objectlib / src / main / java / org / distorted / objectlib / tablebases / PruningTable.java @ f2060de1

1
///////////////////////////////////////////////////////////////////////////////////////////////////
2
// Copyright 2023 Leszek Koltunski                                                               //
3
//                                                                                               //
4
// This file is part of Magic Cube.                                                              //
5
//                                                                                               //
6
// Magic Cube is proprietary software licensed under an EULA which you should have received      //
7
// along with the code. If not, check https://distorted.org/magic/License-Magic-Cube.html        //
8
///////////////////////////////////////////////////////////////////////////////////////////////////
9

    
10
package org.distorted.objectlib.tablebases;
11

    
12
///////////////////////////////////////////////////////////////////////////////////////////////////
13

    
14
class PruningTable
15
{
16
  static final int[] SUPPORTED = {4,8,12,16,20,24};
17

    
18
  private static final int HEADER_SIZE = 6;
19
  private byte[] mData;
20
  private int mNumBits;
21
  private int mBucketBytes;
22
  private int mNumBuckets;
23
  private int mLevel;
24
  private int[] mBucketOffsets;
25

    
26
///////////////////////////////////////////////////////////////////////////////////////////////////
27

    
28
  private static String printByte(byte b)
29
    {
30
    return String.format("%8s", Integer.toBinaryString(b & 0xFF)).replace(' ', '0');
31
    }
32

    
33
///////////////////////////////////////////////////////////////////////////////////////////////////
34

    
35
  private int numBytesForIndices(int total)
36
    {
37
    int bytes = 0;
38

    
39
    while(total!=0)
40
      {
41
      total>>=8;
42
      bytes++;
43
      }
44

    
45
    return bytes;
46
    }
47

    
48
///////////////////////////////////////////////////////////////////////////////////////////////////
49

    
50
  private void writeHeader()
51
    {
52
    mData[0] = (byte)mLevel;
53
    mData[1] = (byte)mBucketBytes;
54
    mData[2] = (byte)mNumBits;
55
    mData[3] = (byte)(mNumBuckets>>>16);
56
    mData[4] = (byte)(mNumBuckets>>> 8);
57
    mData[5] = (byte)(mNumBuckets&0xff);
58
    }
59

    
60
///////////////////////////////////////////////////////////////////////////////////////////////////
61

    
62
  private int getBucketOffset(int bucket)
63
    {
64
    if( mBucketOffsets[bucket]==0 )
65
      {
66
      if( bucket>=mNumBuckets ) mBucketOffsets[bucket] = mData.length;
67
      else
68
        {
69
        int ret = 0;
70
        int start = HEADER_SIZE + bucket*mBucketBytes;
71

    
72
        for(int i=0; i<mBucketBytes; i++)
73
          {
74
          ret<<=8;
75
          ret += ( ((int)(mData[start+i]))&0xff);
76
          }
77

    
78
        mBucketOffsets[bucket] = HEADER_SIZE+mNumBuckets*mBucketBytes+ret;
79
        }
80
      }
81

    
82
    return mBucketOffsets[bucket];
83
    }
84

    
85
///////////////////////////////////////////////////////////////////////////////////////////////////
86

    
87
  private void writeBucketPointers(int[] bucket)
88
    {
89
    int start=HEADER_SIZE;
90

    
91
    for( int b : bucket )
92
      {
93
      for( int i=0; i<mBucketBytes; i++)
94
        {
95
        mData[start+mBucketBytes-1-i] = (byte)(b%256);
96
        b>>=8;
97
        }
98
      start += mBucketBytes;
99
      }
100
    }
101

    
102
///////////////////////////////////////////////////////////////////////////////////////////////////
103

    
104
  private int retrieveData(int index, int entryInBucket)
105
    {
106
    int bits = entryInBucket*mNumBits;
107
    int bytes = index + bits/8;
108
    int ret=0;
109

    
110
    switch(mNumBits)
111
      {
112
      case  4: ret = (bits%8)!=0 ? (mData[bytes]&0xf) : ((mData[bytes]>>>4)&0xf);
113
               break;
114
      case  8: ret = (mData[bytes]&0xff);
115
               break;
116
      case 12: ret = (bits%8)!=0 ? (((mData[bytes]&0xf)<<8) + (mData[bytes+1]&0xff)) : (((mData[bytes]&0xff)<<4) + ((mData[bytes+1]>>>4)&0xf));
117
               break;
118
      case 16: ret = ((mData[bytes]&0xff)<<8) + (mData[bytes+1]&0xff);
119
               break;
120
      case 20: if( (bits%8) !=0 ) ret = (((mData[bytes]&0x0f)<<16) + ((mData[bytes+1]&0xff)<<8) + ( mData[bytes+2]     &0xff));
121
               else               ret = (((mData[bytes]&0xff)<<12) + ((mData[bytes+1]&0xff)<<4) + ((mData[bytes+2]>>>4)&0x0f));
122
               break;
123
      case 24: ret = ((mData[bytes]&0xff)<<16) + ((mData[bytes+1]&0xff)<<8) + (mData[bytes+2]&0xff);
124
               break;
125
      }
126

    
127
    return ret;
128
    }
129

    
130
///////////////////////////////////////////////////////////////////////////////////////////////////
131

    
132
  private void insertData(int index, int entryInBucket, int value)
133
    {
134
    int bits = entryInBucket*mNumBits;
135
    int bytes = index + bits/8;
136

    
137
    switch(mNumBits)
138
      {
139
      case  4: if( (bits%8) !=0 ) mData[bytes] = (byte)((mData[bytes]&(~0xf)) + value  );
140
               else               mData[bytes] = (byte)((value<<4) + (mData[bytes]&0xf));
141
               break;
142
      case  8: mData[bytes] = (byte)value;
143
               break;
144
      case 12: if( (bits%8) !=0 )
145
                 {
146
                 mData[bytes]  = (byte)((mData[bytes]&(~0xf)) + (value>>>8));
147
                 mData[bytes+1]= (byte)value;
148
                 }
149
               else
150
                 {
151
                 mData[bytes]  = (byte)(value>>>4);
152
                 mData[bytes+1]= (byte)((byte)(value<<4) + (mData[bytes+1]&0xf));
153
                 }
154
               break;
155
      case 16: mData[bytes]   = (byte)(value>>>8);
156
               mData[bytes+1] = (byte)(value&(0xff));
157
               break;
158
      case 20: if( (bits%8) !=0 )
159
                 {
160
                 mData[bytes]  = (byte)((mData[bytes]&(~0xf)) + (value>>>16));
161
                 mData[bytes+1]= (byte)(value>>>8);
162
                 mData[bytes+2]= (byte)(value&(0xff));
163
                 }
164
               else
165
                 {
166
                 mData[bytes]  = (byte)(value>>>12);
167
                 mData[bytes+1]= (byte)(value>>>4);
168
                 mData[bytes+2]= (byte)((byte)(value<<4) + (mData[bytes+2]&0xf));
169
                 }
170
               break;
171
      case 24: mData[bytes]   = (byte)(value>>>16);
172
               mData[bytes+1] = (byte)(value>>>8);
173
               mData[bytes+2] = (byte)(value&(0xff));
174
               break;
175
      }
176
    }
177

    
178
///////////////////////////////////////////////////////////////////////////////////////////////////
179

    
180
  private void writeBuckets(Tablebase table)
181
    {
182
    int mask = (1<<mNumBits);
183
    int prevBucket = -1;
184
    int entryInBucket = 0;
185
    int index = 0;
186
    int size = table.getSize();
187

    
188
    for(int i=0; i<size; i++)
189
      if( table.retrieveUnpacked(i)==mLevel )
190
        {
191
        int toBeWritten = (i%mask);
192
        int currBucket = (i>>mNumBits);
193

    
194
        if( currBucket!=prevBucket )
195
          {
196
          index = getBucketOffset(currBucket);
197
          entryInBucket = 0;
198
          }
199
        else
200
          {
201
          entryInBucket ++;
202
          }
203

    
204
        prevBucket = currBucket;
205
        insertData(index,entryInBucket,toBeWritten);
206
        }
207
    }
208

    
209
///////////////////////////////////////////////////////////////////////////////////////////////////
210

    
211
  private int retNumPositions(Tablebase table, int level)
212
    {
213
    int size=0, tableSize = table.getSize();
214

    
215
    for(int i=0; i<tableSize; i++)
216
      if( table.retrieveUnpacked(i)==level ) size++;
217

    
218
    return size;
219
    }
220

    
221
///////////////////////////////////////////////////////////////////////////////////////////////////
222

    
223
  private int computeApproximateSize(int tableSize, int positionSize, int numBits)
224
    {
225
    int entrySize = (1<<numBits);
226
    int numBuckets = (tableSize+entrySize-1)/entrySize;
227
    int overhang = (numBits==4 || numBits==12) ? numBuckets/4 : 0;
228
    int totalBytesForTable = (positionSize*numBits)/8 + overhang;
229
    int bucketBytes = numBytesForIndices(totalBytesForTable);
230
    int totalBytesForBuckets = numBuckets*bucketBytes;
231

    
232
    return HEADER_SIZE + totalBytesForBuckets + totalBytesForTable;
233
    }
234

    
235
///////////////////////////////////////////////////////////////////////////////////////////////////
236

    
237
  private void construct(Tablebase table, int level, int numBits)
238
    {
239
    mNumBits = numBits;
240
    mLevel = level;
241
    int size = table.getSize();
242
    int entrySize = (1<<numBits);
243
    mNumBuckets = (size+entrySize-1)/entrySize;
244
    int[] bucket = new int[mNumBuckets];
245

    
246
    mBucketOffsets = new int[mNumBuckets+1];
247

    
248
    for(int i=0; i<size; i++)
249
      if( table.retrieveUnpacked(i)==level )
250
        {
251
        int currBucket = i/entrySize;
252
        bucket[currBucket]++;
253
        }
254

    
255
    int totalBytesForTable = 0;
256

    
257
    for(int i=0; i<mNumBuckets; i++)
258
      {
259
      int numBytes = (bucket[i]*numBits+7)/8;
260
      bucket[i] = totalBytesForTable;
261
      totalBytesForTable+=numBytes;
262
      }
263

    
264
    mBucketBytes = numBytesForIndices(totalBytesForTable);
265
    int totalBytesForBuckets = mNumBuckets*mBucketBytes;
266
    int totalSize = HEADER_SIZE + totalBytesForBuckets + totalBytesForTable;
267
    mData = new byte[totalSize];
268

    
269
    android.util.Log.e("D", "Constructing a pruningTable, numBits="+numBits+" size "+totalSize);
270

    
271
    writeHeader();
272
    writeBucketPointers(bucket);
273
    writeBuckets(table);
274
    }
275

    
276
///////////////////////////////////////////////////////////////////////////////////////////////////
277
// PUBLIC API
278
///////////////////////////////////////////////////////////////////////////////////////////////////
279
// only numBits = 4,8,12,16,20,24 actually supported
280

    
281
  PruningTable(Tablebase table, int level, int numBits)
282
    {
283
    construct(table,level,numBits);
284
    }
285

    
286
///////////////////////////////////////////////////////////////////////////////////////////////////
287

    
288
  PruningTable(Tablebase table, int level)
289
    {
290
    int minimum = 0, chosen =  0;
291
    int size = retNumPositions(table,level);
292
    int tableSize = table.getSize();
293

    
294
    android.util.Log.e("D", "----- PRUNING TABLE LEVEL "+level+" ------");
295

    
296
    for(int i=0; i<SUPPORTED.length; i++)
297
      {
298
      int approx = computeApproximateSize( tableSize, size, SUPPORTED[i]);
299

    
300
      if( i==0 || approx<minimum )
301
        {
302
        minimum = approx;
303
        chosen  = i;
304
        }
305

    
306
      android.util.Log.e("D", "trying numBits: "+SUPPORTED[i]+" approx size: "+approx);
307
      }
308

    
309
    construct(table,level,SUPPORTED[chosen]);
310
    }
311

    
312
///////////////////////////////////////////////////////////////////////////////////////////////////
313

    
314
  PruningTable(byte[] data)
315
    {
316
    mData = data;
317

    
318
    mLevel       =  ((int)mData[0])&0xff;
319
    mBucketBytes =  ((int)mData[1])&0xff;
320
    mNumBits     =  ((int)mData[2])&0xff;
321
    mNumBuckets  =((((int)mData[3])&0xff)<<16) + ((((int)mData[4])&0xff)<<8) + (((int)mData[5])&0xff);
322

    
323
    mBucketOffsets = new int[mNumBuckets+1];
324
    }
325

    
326
///////////////////////////////////////////////////////////////////////////////////////////////////
327

    
328
  byte[] getPacked()
329
    {
330
    return mData;
331
    }
332

    
333
///////////////////////////////////////////////////////////////////////////////////////////////////
334

    
335
  int getLevel()
336
    {
337
    return mLevel;
338
    }
339

    
340
///////////////////////////////////////////////////////////////////////////////////////////////////
341

    
342
  boolean contains(int number)
343
    {
344
    int bucket = (number>>mNumBits);
345
    int offset1 = getBucketOffset(bucket);
346
    int offset2 = getBucketOffset(bucket+1);
347
    int numBytesInBucket = offset2-offset1;
348
    int numEntriesInBucket = (8*numBytesInBucket)/mNumBits;
349

    
350
    if( mNumBits==4 && numBytesInBucket>0 )
351
      {
352
      byte val = mData[offset2-1];
353
      if( (val/16)*16 == val ) numEntriesInBucket--;
354
      }
355

    
356
    int lower = 0;
357
    int upper = numEntriesInBucket-1;
358
    int rem = (bucket<<mNumBits);
359

    
360
    while( upper>=lower )
361
      {
362
      int mid = (lower+upper)/2;
363
      int value = retrieveData(offset1,mid) + rem;
364
      if( value==number ) return true;
365
      if( value>number ) upper = mid-1;
366
      else               lower = mid+1;
367
      }
368

    
369
    return false;
370
    }
371
}
(2-2/19)