Skip to content

Commit ecbada1

Browse files
committed
[Embedding] Support immutable EmbeddingVariable in inference mode.
1 parent 269b95d commit ecbada1

File tree

6 files changed

+84
-1
lines changed

6 files changed

+84
-1
lines changed

tensorflow/core/framework/embedding/embedding_filter.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,19 @@ class EmbeddingFilter {
5050
public:
5151
virtual void LookupOrCreate(K key, V* val, const V* default_value_ptr,
5252
ValuePtr<V>** value_ptr, int count, const V* default_value_no_permission) = 0;
53+
54+
virtual void Lookup(EV* ev, K key, V* val, const V* default_value_ptr,
55+
const V* default_value_no_permission) {
56+
ValuePtr<V>* value_ptr = nullptr;
57+
Status s = ev->LookupKey(key, &value_ptr);
58+
if (s.ok()) {
59+
V* mem_val = ev->LookupPrimaryEmb(value_ptr);
60+
memcpy(val, mem_val, sizeof(V) * ev->ValueLen());
61+
} else {
62+
memcpy(val, default_value_no_permission, sizeof(V) * ev->ValueLen());
63+
}
64+
}
65+
5366
virtual Status LookupOrCreateKey(K key, ValuePtr<V>** val, bool* is_filter) = 0;
5467
virtual void CreateGPUBatch(V* val_base, V** default_values, int64 size,
5568
int64 slice_elems, int64 value_len_, bool* init_flags, V** memcpy_address) = 0;

tensorflow/core/framework/embedding/embedding_var.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,10 @@ class EmbeddingVar : public ResourceBase {
121121
return is_initialized_;
122122
}
123123

124+
Status LookupKey(K key, ValuePtr<V>** value_ptr) {
125+
return storage_manager_->Get(key, value_ptr);
126+
}
127+
124128
Status LookupOrCreateKey(K key, ValuePtr<V>** value_ptr, bool* is_filter) {
125129
return filter_->LookupOrCreateKey(key, value_ptr, is_filter);
126130
}
@@ -159,6 +163,11 @@ class EmbeddingVar : public ResourceBase {
159163
return filter_->GetFreq(key);
160164
}
161165

166+
void Lookup(K key, V* val, V* default_v) {
167+
const V* default_value_ptr = (default_v == nullptr) ? default_value_ : default_v;
168+
filter_->Lookup(this, key, val, default_value_ptr, default_value_no_permission_);
169+
}
170+
162171
void LookupOrCreate(K key, V* val, V* default_v, int count = 1) {
163172
const V* default_value_ptr = (default_v == nullptr) ? default_value_ : default_v;
164173
ValuePtr<V>* value_ptr = nullptr;

tensorflow/core/framework/embedding/multilevel_embedding.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,18 @@ class StorageManager {
243243
}
244244
}
245245

246+
Status Get(K key, ValuePtr<V>** value_ptr) {
247+
Status s;
248+
int level = 0;
249+
for (; level < hash_table_count_; ++level) {
250+
s = kvs_[level].first->Lookup(key, value_ptr);
251+
if (s.ok()) {
252+
break;
253+
}
254+
}
255+
return s;
256+
}
257+
246258
Status GetOrCreate(K key, ValuePtr<V>** value_ptr, size_t size) {
247259
bool found = false;
248260
int level = 0;

tensorflow/core/kernels/kv_variable_ops.cc

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ namespace tensorflow {
5050
namespace {
5151
const int64 kEmbeddingVarUseDB = -214;
5252
const int64 kInitializableEmbeddingVarUseDB = -215;
53+
const char* kInferenceMode = "INFERENCE_MODE";
5354
}
5455

5556
#define REGISTER_KV_VAR_HANDLE(ktype, vtype) \
@@ -438,6 +439,10 @@ template <typename TKey, typename TValue>
438439
class KvResourceGatherOp : public OpKernel {
439440
public:
440441
explicit KvResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {
442+
OP_REQUIRES_OK(c, c->GetAttr("is_inference", &is_inference_));
443+
bool is_inference;
444+
TF_CHECK_OK(ReadBoolFromEnvVar(kInferenceMode, false, &is_inference));
445+
is_inference_ |= is_inference;
441446
OP_REQUIRES_OK(c,
442447
c->GetAttr("is_use_default_value_tensor",
443448
&is_use_default_value_tensor_));
@@ -461,6 +466,17 @@ class KvResourceGatherOp : public OpKernel {
461466
return 1;
462467
};
463468
}
469+
if (!is_inference_) {
470+
lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
471+
TValue* val, TValue* default_v, int count) {
472+
ev->LookupOrCreate(key, val, default_v, count);
473+
};
474+
} else {
475+
lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
476+
TValue* val, TValue* default_v, int count) {
477+
ev->Lookup(key, val, default_v);
478+
};
479+
}
464480
}
465481

466482
void Compute(OpKernelContext* c) override {
@@ -511,7 +527,7 @@ class KvResourceGatherOp : public OpKernel {
511527
default_v, indices_flat(i), i, ev->GetDefaultValueDim(),
512528
ev->ValueLen());
513529
int32 count = get_count_fn_(counts, i);
514-
ev->LookupOrCreate(indices_flat(i),
530+
lookup_fn_(ev, indices_flat(i),
515531
out_base + i * slice_elems, default_v_ptr, count);
516532
}
517533
};
@@ -530,9 +546,12 @@ class KvResourceGatherOp : public OpKernel {
530546

531547
private:
532548
bool is_use_default_value_tensor_;
549+
bool is_inference_;
533550
std::function<
534551
TValue*(TValue*, TKey, int64, int64, int64)> get_default_v_fn_;
535552
std::function<int32(int32*, int64)> get_count_fn_;
553+
std::function<void(EmbeddingVar<TKey, TValue>* ev,
554+
TKey key, TValue* val, TValue* default_v, int count)> lookup_fn_;
536555
};
537556

538557
#define REGISTER_GATHER_FULL(dev, ktype, vtype) \

tensorflow/core/ops/kv_variable_ops.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ REGISTER_OP("KvResourceGatherV1")
234234
.Input("counts: counts_type")
235235
.Attr("validate_indices: bool = true")
236236
.Attr("is_use_default_value_tensor: bool = false")
237+
.Attr("is_inference: bool = false")
237238
.Output("output: dtype")
238239
.Attr("dtype: type")
239240
.Attr("Tkeys: {int64,int32,string}")
@@ -284,6 +285,7 @@ REGISTER_OP("KvResourceGather")
284285
.Output("output: dtype")
285286
.Attr("dtype: type")
286287
.Attr("Tkeys: {int64,int32,string}")
288+
.Attr("is_inference: bool = false")
287289
.SetShapeFn([](InferenceContext* c) {
288290
ShapeAndType handle_shape_and_type;
289291
TF_RETURN_IF_ERROR(

tensorflow/python/ops/embedding_variable_ops_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from six.moves import xrange # pylint: disable=redefined-builtin
1818

19+
from tensorflow.core.framework import attr_value_pb2
1920
from tensorflow.python.framework import ops
2021
from tensorflow.python.framework import test_util
2122
from tensorflow.python.ops import string_ops
@@ -2236,6 +2237,33 @@ def testEmbeddingVariableForGetFrequencyAndVersion(self):
22362237
self.assertAllEqual(np.array([3,1,2,0,2,0,1]), f)
22372238
self.assertAllEqual(np.array([2,0,1,0,2,0,2]), v)
22382239

2240+
def testEmbeddingVariableForInference(self):
2241+
print("testEmbeddingVariableForInference")
2242+
var = variable_scope.get_embedding_variable("var_1",
2243+
embedding_dim = 3,
2244+
initializer=init_ops.ones_initializer(dtypes.float32),
2245+
ev_option = variables.EmbeddingVariableOption(
2246+
filter_option=variables.CounterFilter(filter_freq=3),
2247+
evict_option=variables.GlobalStepEvict(steps_to_live=2))
2248+
)
2249+
shape=var.get_dynamic_shape()
2250+
ids = array_ops.placeholder(dtype=dtypes.int64, name='ids')
2251+
emb = embedding_ops.embedding_lookup(var, ids)
2252+
# modify graph for infer
2253+
# emb.op.inputs[0].op.inputs[0].op._set_attr("is_inference", attr_value_pb2.AttrValue(b=True))
2254+
# set environment
2255+
os.environ["INFERENCE_MODE"] = "1"
2256+
fun = math_ops.multiply(emb, 2.0, name='multiply')
2257+
loss = math_ops.reduce_sum(fun, name='reduce_sum')
2258+
init = variables.global_variables_initializer()
2259+
with self.test_session() as sess:
2260+
sess.run([init])
2261+
sess.run([emb, loss], feed_dict={'ids:0': [1,2,3]})
2262+
sess.run([emb, loss], feed_dict={'ids:0': [1,3,5]})
2263+
sess.run([emb, loss], feed_dict={'ids:0': [1,5,7]})
2264+
s = sess.run(shape)
2265+
self.assertAllEqual(np.array([0,3]), s)
2266+
22392267
'''
22402268
@test_util.run_gpu_only
22412269
def testEmbeddingVariableForHBMandDRAM(self):

0 commit comments

Comments
 (0)