diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc index 5e48ae97667..6080f32072e 100644 --- a/tensorflow/core/kernels/decode_csv_op.cc +++ b/tensorflow/core/kernels/decode_csv_op.cc @@ -137,6 +137,25 @@ class DecodeCSVOp : public OpKernel { } break; } + case DT_DOUBLE: { + // If this field is empty or NA value, check if default is given: + // If yes, use default value; Otherwise report error. + if (fields[f].empty() || fields[f] == na_value_) { + OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1, + errors::InvalidArgument( + "Field ", f, + " is required but missing in record ", i, "!")); + output[f]->flat()(i) = record_defaults[f].flat()(0); + } else { + double value; + OP_REQUIRES(ctx, strings::safe_strtod(fields[f].c_str(), &value), + errors::InvalidArgument( + "Field ", f, " in record ", i, + " is not a valid double: ", fields[f])); + output[f]->flat()(i) = value; + } + break; + } case DT_STRING: { // If this field is empty or NA value, check if default is given: // If yes, use default value; Otherwise report error. diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc index b44ea2e080e..40ec792ef82 100644 --- a/tensorflow/core/ops/parsing_ops.cc +++ b/tensorflow/core/ops/parsing_ops.cc @@ -329,7 +329,7 @@ REGISTER_OP("DecodeCSV") .Input("records: string") .Input("record_defaults: OUT_TYPE") .Output("output: OUT_TYPE") - .Attr("OUT_TYPE: list({float,int32,int64,string})") + .Attr("OUT_TYPE: list({float,double,int32,int64,string})") .Attr("field_delim: string = ','") .Attr("use_quote_delim: bool = true") .Attr("na_value: string = ''")