Project

General

Profile

Download (11.2 KB) Statistics
| Branch: | Revision:

distorted-objectlib / src / main / java / org / distorted / objectlib / tablebases / PruningTable.java @ b2eb9a1d

1
///////////////////////////////////////////////////////////////////////////////////////////////////
2
// Copyright 2023 Leszek Koltunski                                                               //
3
//                                                                                               //
4
// This file is part of Magic Cube.                                                              //
5
//                                                                                               //
6
// Magic Cube is proprietary software licensed under an EULA which you should have received      //
7
// along with the code. If not, check https://distorted.org/magic/License-Magic-Cube.html        //
8
///////////////////////////////////////////////////////////////////////////////////////////////////
9

    
10
package org.distorted.objectlib.tablebases;
11

    
12
///////////////////////////////////////////////////////////////////////////////////////////////////
13

    
14
class PruningTable
15
{
16
  static final int[] SUPPORTED = {4,8,12,16,20,24};
17

    
18
  private static final int HEADER_SIZE = 6;
19
  private byte[] mData;
20
  private int mNumBits;
21
  private int mBucketBytes;
22
  private int mNumBuckets;
23
  private int mLevel;
24
  private int[] mBucketOffsets;
25

    
26
///////////////////////////////////////////////////////////////////////////////////////////////////
27

    
28
  private static String printByte(byte b)
29
    {
30
    return String.format("%8s", Integer.toBinaryString(b & 0xFF)).replace(' ', '0');
31
    }
32

    
33
///////////////////////////////////////////////////////////////////////////////////////////////////
34

    
35
  private int numBytesForIndices(int total)
36
    {
37
    int bytes = 0;
38

    
39
    while(total!=0)
40
      {
41
      total>>=8;
42
      bytes++;
43
      }
44

    
45
    return bytes;
46
    }
47

    
48
///////////////////////////////////////////////////////////////////////////////////////////////////
49

    
50
  private void writeHeader()
51
    {
52
    mData[0] = (byte)mLevel;
53
    mData[1] = (byte)mBucketBytes;
54
    mData[2] = (byte)mNumBits;
55
    mData[3] = (byte)(mNumBuckets>>>16);
56
    mData[4] = (byte)(mNumBuckets>>> 8);
57
    mData[5] = (byte)(mNumBuckets&0xff);
58
    }
59

    
60
///////////////////////////////////////////////////////////////////////////////////////////////////
61

    
62
  private int getBucketOffset(int bucket)
63
    {
64
    if( mBucketOffsets[bucket]==0 )
65
      {
66
      if( bucket>=mNumBuckets ) mBucketOffsets[bucket] = mData.length;
67
      else
68
        {
69
        int ret = 0;
70
        int start = HEADER_SIZE + bucket*mBucketBytes;
71

    
72
        for(int i=0; i<mBucketBytes; i++)
73
          {
74
          ret<<=8;
75
          ret += ( ((int)(mData[start+i]))&0xff);
76
          }
77

    
78
        mBucketOffsets[bucket] = HEADER_SIZE+mNumBuckets*mBucketBytes+ret;
79
        }
80
      }
81

    
82
    return mBucketOffsets[bucket];
83
    }
84

    
85
///////////////////////////////////////////////////////////////////////////////////////////////////
86

    
87
  private void writeBucketPointers(int[] bucket)
88
    {
89
    int start=HEADER_SIZE;
90

    
91
    for( int b : bucket )
92
      {
93
      for( int i=0; i<mBucketBytes; i++)
94
        {
95
        mData[start+mBucketBytes-1-i] = (byte)(b%256);
96
        b>>=8;
97
        }
98
      start += mBucketBytes;
99
      }
100
    }
101

    
102
///////////////////////////////////////////////////////////////////////////////////////////////////
103

    
104
  private int retrieveData(int index, int entryInBucket)
105
    {
106
    int bits = entryInBucket*mNumBits;
107
    int bytes = index + bits/8;
108

    
109
    switch(mNumBits)
110
      {
111
      case  4: return (bits%8)!=0 ? (mData[bytes]&0xf) : ((mData[bytes]>>>4)&0xf);
112
      case  8: return (mData[bytes]&0xff);
113
      case 12: return (bits%8)!=0 ? (((mData[bytes]&0xf)<<8) + (mData[bytes+1]&0xff)) : (((mData[bytes]&0xff)<<4) + ((mData[bytes+1]>>>4)&0xf));
114
      case 16: return ((mData[bytes]&0xff)<<8) + (mData[bytes+1]&0xff);
115
      case 20: if( (bits%8) !=0 ) return (((mData[bytes]&0x0f)<<16) + ((mData[bytes+1]&0xff)<<8) + ( mData[bytes+2]     &0xff));
116
               else               return (((mData[bytes]&0xff)<<12) + ((mData[bytes+1]&0xff)<<4) + ((mData[bytes+2]>>>4)&0x0f));
117
      case 24: return ((mData[bytes]&0xff)<<16) + ((mData[bytes+1]&0xff)<<8) + (mData[bytes+2]&0xff);
118
      }
119

    
120
    return 0;
121
    }
122

    
123
///////////////////////////////////////////////////////////////////////////////////////////////////
124

    
125
  private void insertData(int index, int entryInBucket, int value)
126
    {
127
    int bits = entryInBucket*mNumBits;
128
    int bytes = index + bits/8;
129

    
130
    switch(mNumBits)
131
      {
132
      case  4: if( (bits%8) !=0 ) mData[bytes] = (byte)((mData[bytes]&(~0xf)) + value  );
133
               else               mData[bytes] = (byte)((value<<4) + (mData[bytes]&0xf));
134
               break;
135
      case  8: mData[bytes] = (byte)value;
136
               break;
137
      case 12: if( (bits%8) !=0 )
138
                 {
139
                 mData[bytes]  = (byte)((mData[bytes]&(~0xf)) + (value>>>8));
140
                 mData[bytes+1]= (byte)value;
141
                 }
142
               else
143
                 {
144
                 mData[bytes]  = (byte)(value>>>4);
145
                 mData[bytes+1]= (byte)((byte)(value<<4) + (mData[bytes+1]&0xf));
146
                 }
147
               break;
148
      case 16: mData[bytes]   = (byte)(value>>>8);
149
               mData[bytes+1] = (byte)(value&(0xff));
150
               break;
151
      case 20: if( (bits%8) !=0 )
152
                 {
153
                 mData[bytes]  = (byte)((mData[bytes]&(~0xf)) + (value>>>16));
154
                 mData[bytes+1]= (byte)(value>>>8);
155
                 mData[bytes+2]= (byte)(value&(0xff));
156
                 }
157
               else
158
                 {
159
                 mData[bytes]  = (byte)(value>>>12);
160
                 mData[bytes+1]= (byte)(value>>>4);
161
                 mData[bytes+2]= (byte)((byte)(value<<4) + (mData[bytes+2]&0xf));
162
                 }
163
               break;
164
      case 24: mData[bytes]   = (byte)(value>>>16);
165
               mData[bytes+1] = (byte)(value>>>8);
166
               mData[bytes+2] = (byte)(value&(0xff));
167
               break;
168
      }
169
    }
170

    
171
///////////////////////////////////////////////////////////////////////////////////////////////////
172

    
173
  private void writeBuckets(Tablebase table)
174
    {
175
    int mask = (1<<mNumBits);
176
    int prevBucket = -1;
177
    int entryInBucket = 0;
178
    int index = 0;
179
    int size = table.getSize();
180

    
181
    for(int i=0; i<size; i++)
182
      if( table.retrieveUnpacked(i)==mLevel )
183
        {
184
        int toBeWritten = (i%mask);
185
        int currBucket = (i>>mNumBits);
186

    
187
        if( currBucket!=prevBucket )
188
          {
189
          index = getBucketOffset(currBucket);
190
          entryInBucket = 0;
191
          }
192
        else
193
          {
194
          entryInBucket ++;
195
          }
196

    
197
        prevBucket = currBucket;
198
        insertData(index,entryInBucket,toBeWritten);
199
        }
200
    }
201

    
202
///////////////////////////////////////////////////////////////////////////////////////////////////
203

    
204
  private int retNumPositions(Tablebase table, int level)
205
    {
206
    int size=0, tableSize = table.getSize();
207

    
208
    for(int i=0; i<tableSize; i++)
209
      if( table.retrieveUnpacked(i)==level ) size++;
210

    
211
    return size;
212
    }
213

    
214
///////////////////////////////////////////////////////////////////////////////////////////////////
215

    
216
  private int computeApproximateSize(int tableSize, int positionSize, int numBits)
217
    {
218
    int entrySize = (1<<numBits);
219
    int numBuckets = (tableSize+entrySize-1)/entrySize;
220
    int overhang = (numBits==4 || numBits==12) ? numBuckets/4 : 0;
221
    int totalBytesForTable = (positionSize*numBits)/8 + overhang;
222
    int bucketBytes = numBytesForIndices(totalBytesForTable);
223
    int totalBytesForBuckets = numBuckets*bucketBytes;
224

    
225
    return HEADER_SIZE + totalBytesForBuckets + totalBytesForTable;
226
    }
227

    
228
///////////////////////////////////////////////////////////////////////////////////////////////////
229

    
230
  private void construct(Tablebase table, int level, int numBits)
231
    {
232
    mNumBits = numBits;
233
    mLevel = level;
234
    int size = table.getSize();
235
    int entrySize = (1<<numBits);
236
    mNumBuckets = (size+entrySize-1)/entrySize;
237
    int[] bucket = new int[mNumBuckets];
238

    
239
    mBucketOffsets = new int[mNumBuckets+1];
240

    
241
    for(int i=0; i<size; i++)
242
      if( table.retrieveUnpacked(i)==level )
243
        {
244
        int currBucket = i/entrySize;
245
        bucket[currBucket]++;
246
        }
247

    
248
    int totalBytesForTable = 0;
249

    
250
    for(int i=0; i<mNumBuckets; i++)
251
      {
252
      int numBytes = (bucket[i]*numBits+7)/8;
253
      bucket[i] = totalBytesForTable;
254
      totalBytesForTable+=numBytes;
255
      }
256

    
257
    mBucketBytes = numBytesForIndices(totalBytesForTable);
258
    int totalBytesForBuckets = mNumBuckets*mBucketBytes;
259
    int totalSize = HEADER_SIZE + totalBytesForBuckets + totalBytesForTable;
260
    mData = new byte[totalSize];
261

    
262
    android.util.Log.e("D", "Constructing a pruningTable, numBits="+numBits+" size "+totalSize);
263

    
264
    writeHeader();
265
    writeBucketPointers(bucket);
266
    writeBuckets(table);
267
    }
268

    
269
///////////////////////////////////////////////////////////////////////////////////////////////////
270
// PUBLIC API
271
///////////////////////////////////////////////////////////////////////////////////////////////////
272
// only numBits = 4,8,12,16,20,24 actually supported
273

    
274
  PruningTable(Tablebase table, int level, int numBits)
275
    {
276
    construct(table,level,numBits);
277
    }
278

    
279
///////////////////////////////////////////////////////////////////////////////////////////////////
280

    
281
  PruningTable(Tablebase table, int level)
282
    {
283
    int minimum = 0, chosen =  0;
284
    int size = retNumPositions(table,level);
285
    int tableSize = table.getSize();
286

    
287
    android.util.Log.e("D", "----- PRUNING TABLE LEVEL "+level+" ------");
288

    
289
    for(int i=0; i<SUPPORTED.length; i++)
290
      {
291
      int approx = computeApproximateSize( tableSize, size, SUPPORTED[i]);
292

    
293
      if( i==0 || approx<minimum )
294
        {
295
        minimum = approx;
296
        chosen  = i;
297
        }
298

    
299
      android.util.Log.e("D", "trying numBits: "+SUPPORTED[i]+" approx size: "+approx);
300
      }
301

    
302
    construct(table,level,SUPPORTED[chosen]);
303
    }
304

    
305
///////////////////////////////////////////////////////////////////////////////////////////////////
306

    
307
  PruningTable(byte[] data)
308
    {
309
    mData = data;
310

    
311
    mLevel       =  ((int)mData[0])&0xff;
312
    mBucketBytes =  ((int)mData[1])&0xff;
313
    mNumBits     =  ((int)mData[2])&0xff;
314
    mNumBuckets  =((((int)mData[3])&0xff)<<16) + ((((int)mData[4])&0xff)<<8) + (((int)mData[5])&0xff);
315

    
316
    mBucketOffsets = new int[mNumBuckets+1];
317
    }
318

    
319
///////////////////////////////////////////////////////////////////////////////////////////////////
320

    
321
  byte[] getPacked()
322
    {
323
    return mData;
324
    }
325

    
326
///////////////////////////////////////////////////////////////////////////////////////////////////
327

    
328
  int getLevel()
329
    {
330
    return mLevel;
331
    }
332

    
333
///////////////////////////////////////////////////////////////////////////////////////////////////
334

    
335
  boolean contains(int number)
336
    {
337
    int bucket = (number>>mNumBits);
338
    int offset1 = getBucketOffset(bucket);
339
    int offset2 = getBucketOffset(bucket+1);
340
    int numBytesInBucket = offset2-offset1;
341
    int numEntriesInBucket = (8*numBytesInBucket)/mNumBits;
342

    
343
    if( mNumBits==4 && numBytesInBucket>0 )
344
      {
345
      byte val = mData[offset2-1];
346
      if( (val/16)*16 == val ) numEntriesInBucket--;
347
      }
348

    
349
    int lower = 0;
350
    int upper = numEntriesInBucket-1;
351
    int rem = (bucket<<mNumBits);
352

    
353
    while( upper>=lower )
354
      {
355
      int mid = (lower+upper)/2;
356
      int value = retrieveData(offset1,mid) + rem;
357
      if( value==number ) return true;
358
      if( value>number ) upper = mid-1;
359
      else               lower = mid+1;
360
      }
361

    
362
    return false;
363
    }
364
}
(2-2/19)