Table Batched Embedding (TBE) Operators¶
- fbgemm_gpu.split_table_batched_embeddings_ops.SplitTableBatchedEmbeddingBagsCodegen(embedding_specs, feature_table_map=None, cache_algorithm=CacheAlgorithm.LRU, cache_load_factor=0.2, cache_sets=0, cache_reserved_memory=0.0, cache_precision=SparseType.FP32, weights_precision=SparseType.FP32, output_dtype=SparseType.FP32, enforce_hbm=False, optimizer=OptimType.EXACT_SGD, record_cache_metrics=None, stochastic_rounding=True, gradient_clipping=False, max_gradient=1.0, learning_rate=0.01, eps=1.0e-8, momentum=0.9, weight_decay=0.0, weight_decay_mode=WeightDecayMode.NONE, eta=0.001, beta1=0.9, beta2=0.999, pooling_mode=PoolingMode.SUM, device=None, bounds_check_mode=BoundsCheckMode.WARNING) None [source]¶
Table Batched Embedding (TBE) operator. Looks up one or more embedding tables. The module is application for training. The backward operator is fused with optimizer. Thus, the embedding tables are updated during backward.
- Parameters:
embedding_specs (List[Tuple[int, int, EmbeddingLocation, ComputeDevice]]) – A list of embedding specifications. Each spec is a tuple of (number of embedding rows, embedding dimension; must be a multiple of 4, table placement, compute device).
feature_table_map (List[int], optional) – An optional list that specifies feature-table mapping.
cache_algorithm (CacheAlgorithm, optional) – LXU cache algorithm (CacheAlgorithm.LRU, CacheAlgorithm.LFU)
cache_load_factor (float, optional) – The LXU cache capacity which is cache_load_factor * the total number of rows in all embedding tables
cache_sets (int, optional) – The number of cache sets
cache_reserved_memory (float, optional) – Amount of memory reserved in HBM for non-cache purpose.
cache_precision (SparseType, optional) – Data type of LXU cache (SparseType.FP32, SparseType.FP16)
weights_precision (SparseType, optional) – Data type of embedding tables (also known as weights) (SparseType.FP32, SparseType.FP16, SparseType.INT8)
output_dtype (SparseType, optional) – Data type of an output tensor (SparseType.FP32, SparseType.FP16, SparseType.INT8)
enforce_hbm (bool, optional) – If True, place all weights/momentums in HBM when using cache
optimizer (OptimType, optional) – An optimizer to use for embedding table update in the backward pass. (OptimType.ADAM, OptimType.EXACT_ADAGRAD, OptimType.EXACT_ROWWISE_ADAGRAD, OptimType.EXACT_SGD, OptimType.LAMB, OptimType.LARS_SGD, OptimType.PARTIAL_ROWWISE_ADAM, OptimType.PARTIAL_ROWWISE_LAMB, OptimType.SGD)
record_cache_metrics (RecordCacheMetrics, optional) – Record number of hits, number of requests, etc if RecordCacheMetrics.record_cache_miss_counter is True and record the similar metrics table-wise if RecordCacheMetrics.record_tablewise_cache_miss is True (default is None).
stochastic_rounding (bool, optional) – If True, apply stochastic rounding for weight type that is not SparseType.FP32
gradient_clipping (bool, optional) – If True, apply gradient clipping
max_gradient (float, optional) – The value for gradient clipping
learning_rate (float, optional) – The learning rate
eps (float, optional) – The epsilon value used by Adagrad, LAMB, and Adam
momentum (float, optional) – Momentum used by LARS-SGD
weight_decay (float, optional) – Weight decay used by LARS-SGD, LAMB, ADAM, and Rowwise Adagrad
weight_decay_mode (WeightDecayMode, optional) – Weight decay mode (WeightDecayMode.NONE, WeightDecayMode.L2, WeightDecayMode.DECOUPLE)
eta (float, optional) – The eta value used by LARS-SGD
beta1 (float, optional) – The beta1 value used by LAMB and ADAM
beta2 (float, optional) – The beta2 value used by LAMB and ADAM
pooling_mode (PoolingMode, optional) – Pooling mode (PoolingMode.SUM, PoolingMode.MEAN, PoolingMode.NONE)
device (torch.device, optional) – The current device to place tensors on
bounds_check_mode (BoundsCheckMode, optional) – If not set to BoundsCheckMode.NONE, apply boundary check for indices (BoundsCheckMode.NONE, BoundsCheckMode.FATAL, BoundsCheckMode.WARNING, BoundsCheckMode.IGNORE)
- Inputs:
indices (torch.Tensor): A 1D-tensor that contains indices to be accessed in all embedding table
offsets (torch.Tensor): A 1D-tensor that conatins offsets of indices. Shape (B * T + 1) where B = batch size and T = number of tables. offsets[t * B + b + 1] - offsets[t * B + b] is the length of bag b of table t
per_sample_weights (torch.Tensor, optional): An optional 1D-tensor that contains positional weights. Shape (max(bag length)). Positional weight i is multiplied to all columns of row i in each bag after its read from the embedding table and before pooling (if pooling mode is not PoolingMode.NONE).
feature_requires_grad (torch.Tensor, optional): An optional tensor for checking if per_sample_weights requires gradient
- Returns:
A 2D-tensor containing looked up data. Shape (B, total_D) where B = batch size and total_D = the sum of all embedding dimensions in the table
Example
>>> import torch >>> >>> from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( >>> EmbeddingLocation, >>> ) >>> from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( >>> SplitTableBatchedEmbeddingBagsCodegen, >>> ComputeDevice, >>> ) >>> >>> # Two tables >>> embedding_specs = [ >>> (3, 8, EmbeddingLocation.DEVICE, ComputeDevice.CUDA), >>> (5, 4, EmbeddingLocation.MANAGED, ComputeDevice.CUDA) >>> ] >>> >>> tbe = SplitTableBatchedEmbeddingBagsCodegen(embedding_specs) >>> tbe.init_embedding_weights_uniform(-1, 1) >>> >>> print(tbe.split_embedding_weights()) [tensor([[-0.9426, 0.7046, 0.4214, -0.0419, 0.1331, -0.7856, -0.8124, -0.2021], [-0.5771, 0.5911, -0.7792, -0.1068, -0.6203, 0.4813, -0.1677, 0.4790], [-0.5587, -0.0941, 0.5754, 0.3475, -0.8952, -0.1964, 0.0810, -0.4174]], device='cuda:0'), tensor([[-0.2513, -0.4039, -0.3775, 0.3273], [-0.5399, -0.0229, -0.1455, -0.8770], [-0.9520, 0.4593, -0.7169, 0.6307], [-0.1765, 0.8757, 0.8614, 0.2051], [-0.0603, -0.9980, -0.7958, -0.5826]], device='cuda:0')]
>>> # Batch size = 3 >>> indices = torch.tensor([0, 1, 2, 0, 1, 2, 0, 3, 1, 4, 2, 0, 0], >>> device="cuda", >>> dtype=torch.long) >>> offsets = torch.tensor([0, 2, 5, 7, 9, 12, 13], >>> device="cuda", >>> dtype=torch.long) >>> >>> output = tbe(indices, offsets) >>> >>> # Batch size = 3, total embedding dimension = 12 >>> print(output.shape) torch.Size([3, 12])
>>> print(output) tensor([[-1.5197, 1.2957, -0.3578, -0.1487, -0.4873, -0.3044, -0.9801, 0.2769, -0.7164, 0.8528, 0.7159, -0.6719], [-2.0784, 1.2016, 0.2176, 0.1988, -1.3825, -0.5008, -0.8991, -0.1405, -1.2637, -0.9427, -1.8902, 0.3754], [-1.5013, 0.6105, 0.9968, 0.3057, -0.7621, -0.9821, -0.7314, -0.6195, -0.2513, -0.4039, -0.3775, 0.3273]], device='cuda:0', grad_fn=<CppNode<SplitLookupFunction_sgd_Op>>)